From 02c3fb6b717456b0efb4aadca178442ffe24f86d Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Fri, 19 Jun 2026 17:08:38 -0700 Subject: [PATCH 1/4] =?UTF-8?q?Improve=20diskann-disk=20test=20coverage:?= =?UTF-8?q?=2094.4%=20=E2=86=92=2095.0%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 18 targeted unit tests covering previously untested error paths, control flow branches, and pure functions across 7 files. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../checkpoint_record_manager_with_file.rs | 71 +++++++++ .../src/build/chunking/continuation/utils.rs | 144 ++++++++++++++++++ diskann-disk/src/search/pq/pq_scratch.rs | 25 +++ .../src/search/provider/disk_sector_graph.rs | 44 ++++++ diskann-disk/src/storage/quant/generator.rs | 36 +++++ diskann-disk/src/utils/kmeans.rs | 67 ++++++++ diskann-disk/src/utils/partition.rs | 52 +++++++ 7 files changed, 439 insertions(+) diff --git a/diskann-disk/src/build/chunking/checkpoint/checkpoint_record_manager_with_file.rs b/diskann-disk/src/build/chunking/checkpoint/checkpoint_record_manager_with_file.rs index 774788d3b..eb69a8840 100644 --- a/diskann-disk/src/build/chunking/checkpoint/checkpoint_record_manager_with_file.rs +++ b/diskann-disk/src/build/chunking/checkpoint/checkpoint_record_manager_with_file.rs @@ -104,6 +104,77 @@ mod tests { } } + #[test] + fn test_new_checkpoint_creates_fresh_record() -> ANNResult<()> { + let temp_dir = tempdir()?; + let index_prefix = temp_dir + .path() + .join("fresh_index") + .to_str() + .unwrap() + .to_string(); + // Two managers with the same prefix+identifier should see the same checkpoint state + let manager_a = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42); + let manager_b = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42); + assert_eq!( + manager_a.get_resumption_point(WorkStage::Start)?, + manager_b.get_resumption_point(WorkStage::Start)? + ); + // A different identifier should be independent + let manager_c = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 99); + assert!(!manager_c.has_completed()?); + Ok(()) + } + + #[test] + fn test_has_completed_false_when_no_file() -> ANNResult<()> { + let temp_dir = tempdir()?; + let index_prefix = temp_dir + .path() + .join("nonexistent_index") + .to_str() + .unwrap() + .to_string(); + let manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 999); + assert!(!manager.has_completed()?); + Ok(()) + } + + #[test] + fn test_mark_as_invalid() -> ANNResult<()> { + let temp_dir = tempdir()?; + let index_prefix = temp_dir + .path() + .join("test_invalid") + .to_str() + .unwrap() + .to_string(); + let identifier = 77; + + let mut manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier); + // Advance to a later stage with some progress + manager.update(Progress::Completed, WorkStage::QuantizeFPV)?; + manager.update(Progress::Processed(42), WorkStage::InMemIndexBuild)?; + + // Verify we can resume from progress=42 + let manager2 = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier); + assert_eq!( + manager2.get_resumption_point(WorkStage::QuantizeFPV)?, + Some(42) + ); + + // Mark as invalid - progress resets to 0 (is_valid=false => progress read as 0) + let mut manager3 = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier); + manager3.mark_as_invalid()?; + assert_eq!( + manager3.get_resumption_point(WorkStage::QuantizeFPV)?, + Some(0) + ); + + clean_checkpoint_file(&index_prefix, identifier); + Ok(()) + } + #[test] fn test_checkpoint_manager_interruption_and_resumption() -> ANNResult<()> { let temp_dir = tempdir()?; diff --git a/diskann-disk/src/build/chunking/continuation/utils.rs b/diskann-disk/src/build/chunking/continuation/utils.rs index 0d92e416e..53b664956 100644 --- a/diskann-disk/src/build/chunking/continuation/utils.rs +++ b/diskann-disk/src/build/chunking/continuation/utils.rs @@ -150,6 +150,150 @@ mod tests { } } + /// A tracker that returns Stop after `stop_after` Continue grants. + #[derive(Clone)] + struct StopAfterTracker { + count: std::sync::Arc>, + stop_after: usize, + } + + impl ContinuationTrackerTrait for StopAfterTracker { + fn get_continuation_grant(&self) -> ContinuationGrant { + let mut count = self.count.lock().unwrap(); + if *count >= self.stop_after { + ContinuationGrant::Stop + } else { + *count += 1; + ContinuationGrant::Continue + } + } + } + + #[test] + fn test_process_while_resource_is_available_stops_early() { + let tracker = StopAfterTracker { + count: std::sync::Arc::new(std::sync::Mutex::new(0)), + stop_after: 3, + }; + let items = vec![10, 20, 30, 40, 50]; + let mut processed = Vec::new(); + + let result = process_while_resource_is_available( + |item| { + processed.push(item); + Ok::<(), TestError>(()) + }, + items.into_iter(), + Box::new(tracker), + ); + + assert!(result.is_ok()); + match result.unwrap() { + Progress::Processed(idx) => { + assert_eq!(idx, 3); // stopped before processing item at index 3 + assert_eq!(processed, vec![10, 20, 30]); + } + _ => panic!("Expected Processed"), + } + } + + /// A tracker that yields once (with a tiny duration), then continues. + #[derive(Clone)] + struct YieldOnceThenContinueTracker { + yielded: std::sync::Arc>, + } + + impl ContinuationTrackerTrait for YieldOnceThenContinueTracker { + fn get_continuation_grant(&self) -> ContinuationGrant { + let mut yielded = self.yielded.lock().unwrap(); + if !*yielded { + *yielded = true; + ContinuationGrant::Yield(std::time::Duration::from_millis(1)) + } else { + // After yielding once, always continue + ContinuationGrant::Continue + } + } + } + + #[test] + fn test_process_while_resource_is_available_yield_then_continue() { + let tracker = YieldOnceThenContinueTracker { + yielded: std::sync::Arc::new(std::sync::Mutex::new(false)), + }; + let items = vec![1, 2]; + let mut processed = Vec::new(); + + let result = process_while_resource_is_available( + |item| { + processed.push(item); + Ok::<(), TestError>(()) + }, + items.into_iter(), + Box::new(tracker), + ); + + assert!(result.is_ok()); + // After yielding, it should have continued and processed all items + match result.unwrap() { + Progress::Completed => assert_eq!(processed, vec![1, 2]), + _ => panic!("Expected Completed"), + } + } + + #[test] + fn test_process_while_resource_is_available_action_error() { + let checker = Box::new(NaiveContinuationTracker::default()); + let items = vec![1, 2, 3]; + + let result = process_while_resource_is_available( + |item| { + if item == 2 { + Err(TestError) + } else { + Ok(()) + } + }, + items.into_iter(), + checker, + ); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_process_while_resource_is_available_async_stops_early() { + let tracker = StopAfterTracker { + count: std::sync::Arc::new(std::sync::Mutex::new(0)), + stop_after: 2, + }; + let items = vec![1, 2, 3, 4, 5]; + let processed = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())); + + let result = process_while_resource_is_available_async( + |item| { + let processed = processed.clone(); + async move { + processed.lock().await.push(item); + Ok::<(), TestError>(()) + } + }, + items.into_iter(), + Box::new(tracker), + ) + .await; + + assert!(result.is_ok()); + match result.unwrap() { + Progress::Processed(idx) => { + assert_eq!(idx, 2); + let processed = processed.lock().await; + assert_eq!(*processed, vec![1, 2]); + } + _ => panic!("Expected Processed"), + } + } + #[tokio::test] async fn test_process_while_resource_is_available_async_completes() { let checker = Box::new(NaiveContinuationTracker::default()); diff --git a/diskann-disk/src/search/pq/pq_scratch.rs b/diskann-disk/src/search/pq/pq_scratch.rs index 1df666e18..13cb0d8cd 100644 --- a/diskann-disk/src/search/pq/pq_scratch.rs +++ b/diskann-disk/src/search/pq/pq_scratch.rs @@ -128,4 +128,29 @@ mod tests { assert_eq!(pq_scratch.query_scratch[i], query[i]); }); } + + #[test] + fn test_pq_scratch_set_rejects_short_query() { + let dim = 16; + let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap(); + + // Query shorter than dim should fail + let short_query: Vec = (1..dim).map(|i| i as f32).collect(); // dim-1 elements + let err = pq_scratch.set(&short_query).unwrap_err(); + assert!(err.to_string().contains("expected query of length")); + } + + #[test] + fn test_pq_scratch_set_accepts_oversized_query() { + let dim = 8; + let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap(); + + // Query longer than dim should succeed (only first `dim` elements used) + let long_query: Vec = (1..=dim + 10).map(|i| i as f32).collect(); + pq_scratch.set(&long_query).unwrap(); + + for (i, &val) in long_query.iter().enumerate().take(dim) { + assert_eq!(pq_scratch.query_scratch[i], val); + } + } } diff --git a/diskann-disk/src/search/provider/disk_sector_graph.rs b/diskann-disk/src/search/provider/disk_sector_graph.rs index 117b492e5..9724a74f4 100644 --- a/diskann-disk/src/search/provider/disk_sector_graph.rs +++ b/diskann-disk/src/search/provider/disk_sector_graph.rs @@ -373,4 +373,48 @@ mod disk_sector_graph_test { let data = &graph; assert_eq!(data.len(), 512); } + + #[test] + fn test_reconfigure_grows_buffer() { + let reader = AlignedFileReaderFactory::new(test_index_path()) + .build() + .unwrap(); + let mut graph = test_initialize_disk_sector_graph(2, 1, reader); + assert_eq!(graph.max_n_batch_sector_read, 4); + + // Reconfigure to larger batch — buffer must grow beyond initial 512 bytes + graph.reconfigure(16).unwrap(); + assert_eq!(graph.max_n_batch_sector_read, 16); + assert_eq!(graph.sectors_data.len(), 16 * 64); + } + + #[test] + fn test_reconfigure_noop_for_smaller_size() { + let reader = AlignedFileReaderFactory::new(test_index_path()) + .build() + .unwrap(); + let mut graph = test_initialize_disk_sector_graph(2, 1, reader); + let original_len = graph.sectors_data.len(); + + // Reconfigure with same or smaller size should be a no-op + graph.reconfigure(4).unwrap(); + assert_eq!(graph.max_n_batch_sector_read, 4); + assert_eq!(graph.sectors_data.len(), original_len); + + graph.reconfigure(2).unwrap(); + assert_eq!(graph.max_n_batch_sector_read, 4); + assert_eq!(graph.sectors_data.len(), original_len); + } + + #[test] + fn test_new_disk_sector_graph_zero_block_size_defaults() { + let metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256); + // block_size = 0 should fall back to DEFAULT_DISK_SECTOR_LEN regardless of version + let header = GraphHeader::new(metadata, 0, GraphLayoutVersion::new(1, 0)); + let reader = AlignedFileReaderFactory::new(test_index_path()) + .build() + .unwrap(); + let graph = DiskSectorGraph::new(reader, &header, 2).unwrap(); + assert_eq!(graph.block_size, DEFAULT_DISK_SECTOR_LEN); + } } diff --git a/diskann-disk/src/storage/quant/generator.rs b/diskann-disk/src/storage/quant/generator.rs index c14bf998a..3d830f515 100644 --- a/diskann-disk/src/storage/quant/generator.rs +++ b/diskann-disk/src/storage/quant/generator.rs @@ -598,4 +598,40 @@ mod generator_tests { Ok(()) } + + #[test] + fn test_validate_params_missing_compressed_file() -> ANNResult<()> { + let storage_provider = VirtualStorageProvider::new_memory(); + storage_provider + .filesystem() + .create_dir("/test_data") + .expect("Could not create test directory"); + + let data_path = "/test_data/data.bin"; + let compressed_path = "/test_data/compressed.bin"; + let num_points = 100; + let dim = 8; + let output_dim = 4u32; + + // Create source data + let data = create_test_data(num_points, dim); + let view = MatrixView::try_from(data.as_slice(), num_points, dim).unwrap(); + write_bin(view, &mut storage_provider.create_for_write(data_path)?)?; + + // Don't create compressed file but set offset > 0 + let context = GeneratorContext::new(10, compressed_path.to_string()); + let generator = QuantDataGenerator::::new( + data_path.to_string(), + context, + &output_dim, + ) + .unwrap(); + + let err = generator.validate_params(num_points, &storage_provider); + assert!(err.is_err()); + let err_msg = format!("{:?}", err.unwrap_err()); + assert!(err_msg.contains("expected compressed file")); + + Ok(()) + } } diff --git a/diskann-disk/src/utils/kmeans.rs b/diskann-disk/src/utils/kmeans.rs index 0c7e895ec..7ee6ba384 100644 --- a/diskann-disk/src/utils/kmeans.rs +++ b/diskann-disk/src/utils/kmeans.rs @@ -992,4 +992,71 @@ mod kmeans_test { k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),pool.as_ref()).unwrap(); } } + + #[test] + fn k_means_clustering_produces_valid_output() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 5; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = vec![0.0; num_centers * dim]; + let pool = create_thread_pool_for_test(); + + let (closest_docs, closest_center, residual) = k_means_clustering( + &data, + num_points, + dim, + &mut centers, + num_centers, + max_reps, + &mut create_rnd_in_tests(), + &mut false, + pool.as_ref(), + ) + .unwrap(); + + // Check shapes + assert_eq!(closest_docs.len(), num_centers); + assert_eq!(closest_center.len(), num_points); + assert!(residual >= 0.0); + + // Every point should be assigned to exactly one cluster + let total_assigned: usize = closest_docs.iter().map(|d| d.len()).sum(); + assert_eq!(total_assigned, num_points); + + // closest_center values should be in [0, num_centers) + for &cc in &closest_center { + assert!((cc as usize) < num_centers); + } + } + + #[test] + fn k_means_clustering_returns_err_when_canceled() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 5; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = vec![0.0; num_centers * dim]; + let pool = create_thread_pool_for_test(); + + let err = k_means_clustering( + &data, + num_points, + dim, + &mut centers, + num_centers, + max_reps, + &mut create_rnd_in_tests(), + &mut true, // canceled + pool.as_ref(), + ) + .unwrap_err(); + + assert_eq!(err.kind(), ANNErrorKind::PQError); + assert!(err.to_string().contains("Cancellation requested by caller")); + } } diff --git a/diskann-disk/src/utils/partition.rs b/diskann-disk/src/utils/partition.rs index 0e9a05b0f..a8556f5d5 100644 --- a/diskann-disk/src/utils/partition.rs +++ b/diskann-disk/src/utils/partition.rs @@ -545,6 +545,58 @@ mod partition_test { buffer } + #[test] + fn test_estimate_initial_partition_count_minimum_clamp() { + // When total RAM fits well within budget, should clamp to minimum of 3 + let count = estimate_initial_partition_count( + 100, // total_points + 10, // dimension + 1, // k_base + 1_000_000.0, // ram_budget_in_bytes + &|n, _d| n as f64 * 100.0, // total_ram = 100 * 100 = 10_000 << 1_000_000 => clamp to 3 + ); + assert_eq!(count, 3); + } + + #[test] + fn test_estimate_initial_partition_count_odd_rounding() { + // Even partition count should be bumped to odd + let count = estimate_initial_partition_count( + 1000, + 128, + 1, + 1000.0, // budget + &|n, _d| n as f64 * 4.0, // total_ram = 4000, ratio = 4 => ceil = 4 (even) => 5 + ); + assert_eq!(count, 5); + } + + #[test] + fn test_estimate_initial_partition_count_large_ratio() { + // Odd result that is >= 3 should be returned as-is + let count = estimate_initial_partition_count( + 1000, + 128, + 1, + 1000.0, // budget + &|n, _d| n as f64 * 7.0, // total_ram = 7000, ratio = 7 => odd, >= 3 + ); + assert_eq!(count, 7); + } + + #[test] + fn test_estimate_initial_partition_count_k_base_multiplier() { + // k_base multiplies total_points in the estimator call + let count = estimate_initial_partition_count( + 100, + 10, + 3, // k_base + 100.0, // budget + &|n, _d| n as f64 * 1.0, // n = total_points * k_base = 300, total_ram = 300, ratio = 3 + ); + assert_eq!(count, 3); + } + #[test] fn test_partition_with_ram_budget() -> ANNResult<()> { let storage_provider = VirtualStorageProvider::new_overlay(test_data_root()); From 94cfbf4031bd1d064dfb6ab7948f5b7a3491adbd Mon Sep 17 00:00:00 2001 From: Alex Razumov Date: Sun, 21 Jun 2026 19:33:05 -0700 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-disk/src/build/chunking/continuation/utils.rs | 2 +- diskann-disk/src/search/pq/pq_scratch.rs | 2 +- diskann-disk/src/storage/quant/generator.rs | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/diskann-disk/src/build/chunking/continuation/utils.rs b/diskann-disk/src/build/chunking/continuation/utils.rs index 53b664956..609005988 100644 --- a/diskann-disk/src/build/chunking/continuation/utils.rs +++ b/diskann-disk/src/build/chunking/continuation/utils.rs @@ -208,7 +208,7 @@ mod tests { let mut yielded = self.yielded.lock().unwrap(); if !*yielded { *yielded = true; - ContinuationGrant::Yield(std::time::Duration::from_millis(1)) + ContinuationGrant::Yield(std::time::Duration::ZERO) } else { // After yielding once, always continue ContinuationGrant::Continue diff --git a/diskann-disk/src/search/pq/pq_scratch.rs b/diskann-disk/src/search/pq/pq_scratch.rs index 13cb0d8cd..cae2df83a 100644 --- a/diskann-disk/src/search/pq/pq_scratch.rs +++ b/diskann-disk/src/search/pq/pq_scratch.rs @@ -137,8 +137,8 @@ mod tests { // Query shorter than dim should fail let short_query: Vec = (1..dim).map(|i| i as f32).collect(); // dim-1 elements let err = pq_scratch.set(&short_query).unwrap_err(); + assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError); assert!(err.to_string().contains("expected query of length")); - } #[test] fn test_pq_scratch_set_accepts_oversized_query() { diff --git a/diskann-disk/src/storage/quant/generator.rs b/diskann-disk/src/storage/quant/generator.rs index 3d830f515..957e057d9 100644 --- a/diskann-disk/src/storage/quant/generator.rs +++ b/diskann-disk/src/storage/quant/generator.rs @@ -627,11 +627,11 @@ mod generator_tests { ) .unwrap(); - let err = generator.validate_params(num_points, &storage_provider); - assert!(err.is_err()); - let err_msg = format!("{:?}", err.unwrap_err()); - assert!(err_msg.contains("expected compressed file")); - + let err = generator + .validate_params(num_points, &storage_provider) + .unwrap_err(); + assert_eq!(err.kind(), diskann::ANNErrorKind::FileNotFoundError); + assert!(err.to_string().contains("expected compressed file")); Ok(()) } } From 6d26e5a62cc859fd9bfea24e1e6c157a5198a7c1 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 22 Jun 2026 11:03:37 -0700 Subject: [PATCH 3/4] Empty commit From 31bb404d03756d069ea2b31341b25029ce6eb314 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 22 Jun 2026 16:23:25 -0700 Subject: [PATCH 4/4] Fix missing closing brace in pq_scratch test --- diskann-disk/src/search/pq/pq_scratch.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/diskann-disk/src/search/pq/pq_scratch.rs b/diskann-disk/src/search/pq/pq_scratch.rs index cae2df83a..6434dd8dd 100644 --- a/diskann-disk/src/search/pq/pq_scratch.rs +++ b/diskann-disk/src/search/pq/pq_scratch.rs @@ -139,6 +139,7 @@ mod tests { let err = pq_scratch.set(&short_query).unwrap_err(); assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError); assert!(err.to_string().contains("expected query of length")); + } #[test] fn test_pq_scratch_set_accepts_oversized_query() {