👉 Code release for Retrieval-Augmented Anatomical Guidance for Text-to-CT Generation
Molino, D., Caruso, C. M., Soda, P., Guarrasi, V. (2026)
This repository extends the original Text2CT model with a retrieval-guided anatomical branch. Given a report embedding, the pipeline retrieves a semantically related case and uses its anatomical mask as structural guidance through a 3D ControlNetMaisi branch.
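The retrieval step can be pictured as a nearest-neighbour lookup over report embeddings. The sketch below illustrates the idea and is not the repository's code. It assumes cosine similarity on L2-normalised embeddings and top-1 selection. The helper name retrieve_mask_path and the mask filenames are hypothetical.

```python
import numpy as np


def retrieve_mask_path(query, bank, mask_paths):
    """Return (mask_path, score) for the bank entry most similar to `query`.

    Hypothetical helper: `bank` is an (N, D) array of report embeddings and
    `mask_paths[i]` is the anatomical mask of the case behind row i.
    Cosine similarity is an assumption made for this illustration.
    """
    query = np.asarray(query, dtype=np.float32)
    bank = np.asarray(bank, dtype=np.float32)
    # L2-normalise so that the dot product equals cosine similarity
    q = query / np.linalg.norm(query)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    scores = b @ q                      # (N,) similarity of every case to the query
    best = int(np.argmax(scores))       # top-1 retrieval
    return mask_paths[best], float(scores[best])


# Toy bank of three cases with 3-D embeddings (placeholder values)
bank = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.7, 0.7, 0.0]])
paths = ["case_a_mask.nii.gz", "case_b_mask.nii.gz", "case_c_mask.nii.gz"]
path, score = retrieve_mask_path([0.1, 0.9, 0.0], bank, paths)
print(path, round(score, 3))
```

The retrieved mask, not the retrieved CT, is what the ControlNet branch consumes as structural guidance.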
The release is organized around four stages:
- Train or reuse the original Text2CT backbone:
  - CLIP3D text encoder
  - 3D autoencoder
  - text-conditioned latent diffusion UNet
- Precompute report embeddings with scripts/save_embeddings_ctrate.py.
- Build the retrieval bank with scripts/build_rag_index.py.
- Train and run the retrieval-augmented ControlNet with scripts/train_controlnet.py and scripts/infer_controlnet_RAG.py.
Python 3.10+ is recommended.
pip install -r requirements.txt

The pretrained weights can be downloaded from Hugging Face:
from huggingface_hub import snapshot_download

# Download the full RAGText2CT model repository from the Hugging Face Hub
repo_id = "dmolino/RAGText2CT"
snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    local_dir="your_local_path",  # replace with your target directory
)

Place the checkpoints in models/:
- autoencoder_epoch273.pt
- unet_rflow_200ep.pt
- CLIP3D_Finding_Impression_30ep.pt
- controlnet_rag_best.pt
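Before running any stage, it can help to confirm that the four checkpoints are in place. The helper below is hypothetical (missing_checkpoints is not part of the repo). It only checks for the filenames listed above inside models/ and does not validate their contents.

```python
from pathlib import Path

# Checkpoint filenames listed in this README
EXPECTED_CHECKPOINTS = [
    "autoencoder_epoch273.pt",
    "unet_rflow_200ep.pt",
    "CLIP3D_Finding_Impression_30ep.pt",
    "controlnet_rag_best.pt",
]


def missing_checkpoints(models_dir="models"):
    """Return the expected checkpoint files that are absent from models_dir."""
    root = Path(models_dir)
    return [name for name in EXPECTED_CHECKPOINTS if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints("models")
    if missing:
        print("Missing checkpoints:", ", ".join(missing))
    else:
        print("All checkpoints found.")
```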
This release assumes the CT-RATE-style JSON lists already point to image paths like:
dataset/train/...
dataset/valid/...
The repo expects:
- dataset/train_data_volumes.json
- dataset/validation_data_volumes.json
- dataset/train_reports.csv
- dataset/validation_reports.csv
- CT volumes under dataset/
- segmentation masks under a parallel tree such as dataset/ts_seg/ts_total/...
Default configs use data_base_dir: ".", so JSON paths remain valid as-is.
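The sketch below shows how a JSON-listed path resolves against data_base_dir, and how a "parallel tree" could map a CT to its mask. Both helpers are hypothetical. It assumes relative paths are simply joined onto data_base_dir. The assumption that a mask reuses the CT's relative sub-path and filename under dataset/ts_seg/ts_total/ is not specified by this README, so check it against your own segmentation export. example_case.nii.gz is a placeholder name.

```python
import os


def resolve_image_path(json_path, data_base_dir="."):
    """Join a path from the JSON list onto data_base_dir (assumed joining rule)."""
    return os.path.normpath(os.path.join(data_base_dir, json_path))


def parallel_mask_path(ct_path, image_root="dataset", mask_root="dataset/ts_seg/ts_total"):
    """Mirror ct_path's location under image_root into mask_root.

    Hypothetical layout: the mask keeps the CT's relative sub-path and filename.
    """
    rel = os.path.relpath(ct_path, image_root)
    if rel == os.pardir or rel.startswith(os.pardir + os.sep):
        raise ValueError(f"{ct_path} is not under {image_root}")
    return os.path.normpath(os.path.join(mask_root, rel))


# With data_base_dir "." the JSON path is returned unchanged
ct = resolve_image_path("dataset/train/example_case.nii.gz")
print(ct)
print(parallel_mask_path(ct))
```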
python scripts/save_embeddings_ctrate.py \
--train_json dataset/train_data_volumes.json \
--val_json dataset/validation_data_volumes.json \
--train_reports dataset/train_reports.csv \
--val_reports dataset/validation_reports.csv \
--data_base_dir . \
--embedding_base_dir ./embeddings \
--clip_weights ./models/CLIP3D_Finding_Impression_30ep.pt \
--report_encoder_model xgem_3D

python scripts/diff_model_create_training_data.py \
--model_def ./configs/config_rflow.json \
--model_config ./configs/config_diff_model.json \
--env_config ./configs/environment_diff_model_train.json \
--num_gpus 1 \
--index 0

python scripts/build_rag_index.py \
--data_list dataset/train_data_volumes.json \
--embedding_base_dir ./embeddings \
--report_encoder_model xgem_3D \
--output_dir ./retrieval

This produces:
- retrieval/impression_embeddings.npy
- retrieval/impression_paths.json
The paths file stores mask paths used as anatomical proxies during retrieval.
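build_rag_index.py is described as stacking per-case report embeddings and exporting the retrieval mapping. The following is a minimal sketch of that idea, not the real script. It assumes row i of impression_embeddings.npy corresponds to entry i of impression_paths.json, and that the JSON file is a plain list of mask paths. The actual file format and any normalisation may differ. build_bank and load_bank are hypothetical names.

```python
import json
from pathlib import Path

import numpy as np


def build_bank(embeddings, mask_paths, output_dir="retrieval"):
    """Stack per-case embeddings into (N, D) and write the two bank files.

    Assumption: one embedding per case, aligned by position with mask_paths.
    """
    if len(embeddings) != len(mask_paths):
        raise ValueError("need exactly one mask path per embedding")
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Flatten each case's embedding and stack them row-wise
    bank = np.stack([np.asarray(e, dtype=np.float32).ravel() for e in embeddings])
    np.save(out / "impression_embeddings.npy", bank)
    with open(out / "impression_paths.json", "w") as f:
        json.dump(list(mask_paths), f, indent=2)  # assumed: plain JSON list
    return bank.shape


def load_bank(output_dir="retrieval"):
    """Read the bank back as (embeddings array, list of mask paths)."""
    out = Path(output_dir)
    bank = np.load(out / "impression_embeddings.npy")
    with open(out / "impression_paths.json") as f:
        paths = json.load(f)
    return bank, paths
```

Keeping the bank as two flat files means it can be rebuilt or inspected without any index library; the search structure is only created when the files are loaded.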
python scripts/train_controlnet.py \
--environment-file ./configs/environment_rag_controlnet_train.json \
--config-file ./configs/config_rag_rflow.json \
--training-config ./configs/config_rag_controlnet.json \
--gpus 1

python scripts/infer_controlnet_RAG.py \
--environment-file ./configs/environment_rag_controlnet_eval.json \
--config-file ./configs/config_rag_rflow.json \
--training-config ./configs/config_rag_controlnet.json \
--gpus 1 \
--index 0

Outputs are written under predictions/rag/.
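The commands above can be replayed in the order they are listed. This hypothetical dry-run driver is not shipped with the repo. It copies the flags shown in this README verbatim (single GPU, index 0). It only executes them via subprocess.run when dry_run=False, and it stops at the first failing stage.

```python
import subprocess
import sys

# Hypothetical driver: each stage replays a command from this README verbatim
STAGES = [
    ("save_embeddings_ctrate", [
        "scripts/save_embeddings_ctrate.py",
        "--train_json", "dataset/train_data_volumes.json",
        "--val_json", "dataset/validation_data_volumes.json",
        "--train_reports", "dataset/train_reports.csv",
        "--val_reports", "dataset/validation_reports.csv",
        "--data_base_dir", ".",
        "--embedding_base_dir", "./embeddings",
        "--clip_weights", "./models/CLIP3D_Finding_Impression_30ep.pt",
        "--report_encoder_model", "xgem_3D",
    ]),
    ("diff_model_create_training_data", [
        "scripts/diff_model_create_training_data.py",
        "--model_def", "./configs/config_rflow.json",
        "--model_config", "./configs/config_diff_model.json",
        "--env_config", "./configs/environment_diff_model_train.json",
        "--num_gpus", "1",
        "--index", "0",
    ]),
    ("build_rag_index", [
        "scripts/build_rag_index.py",
        "--data_list", "dataset/train_data_volumes.json",
        "--embedding_base_dir", "./embeddings",
        "--report_encoder_model", "xgem_3D",
        "--output_dir", "./retrieval",
    ]),
    ("train_controlnet", [
        "scripts/train_controlnet.py",
        "--environment-file", "./configs/environment_rag_controlnet_train.json",
        "--config-file", "./configs/config_rag_rflow.json",
        "--training-config", "./configs/config_rag_controlnet.json",
        "--gpus", "1",
    ]),
    ("infer_controlnet_RAG", [
        "scripts/infer_controlnet_RAG.py",
        "--environment-file", "./configs/environment_rag_controlnet_eval.json",
        "--config-file", "./configs/config_rag_rflow.json",
        "--training-config", "./configs/config_rag_controlnet.json",
        "--gpus", "1",
        "--index", "0",
    ]),
]


def run_pipeline(start=0, dry_run=True):
    """Print (and optionally run) each stage command in order, from stage `start`."""
    commands = []
    for name, args in STAGES[start:]:
        cmd = [sys.executable, *args]  # same interpreter as "python ..."
        commands.append(cmd)
        print(f"[{name}]", " ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raise on the first failing stage
    return commands


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

For example, run_pipeline(start=3, dry_run=False) would rerun only ControlNet training and inference once the embeddings and retrieval bank exist.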
For a minimal end-to-end smoke test, place one CT and one mask in the repo root as ct.nii.gz and mask.nii.gz, then run:
python scripts/rag_demo_single.py \
--report "Small right pleural effusion. Mild bibasal atelectatic changes. No focal lung consolidation." \
--ct ct.nii.gz \
--mask mask.nii.gz \
--weights-dir hf_ragtext2ct/models \
--output predictions/rag_demo_single.nii.gz

This demo encodes the report, uses the provided mask as the singleton retrieval target, and writes:
- predictions/rag_demo_single.nii.gz
- predictions/rag_demo_single.json
- configs/config_rag_rflow.json: model definitions for the autoencoder, UNet and ControlNet.
- configs/config_rag_controlnet.json: ControlNet training and inference hyperparameters.
- configs/environment_rag_controlnet_train.json: training paths.
- configs/environment_rag_controlnet_eval.json: evaluation paths and retrieval bank paths.
- scripts/build_rag_index.py: stacks per-case report embeddings and exports the retrieval mapping.
- scripts/train_controlnet.py: ControlNet training.
- scripts/infer_controlnet_RAG.py: retrieval-augmented generation.
- This repo is intentionally model-focused. Challenge-specific wrappers and VLM3D container code are excluded.
- The retrieval bank is currently file-based (.npy + .json) and loaded into a FAISS inner-product index at inference time.
- The original Text2CT pipeline is still available through the diff_model_* scripts.
- For HF downloads of the weights bundle, see dmolino/RAGText2CT.
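The inference-time step, loading the file-based bank into a FAISS inner-product index, can be sketched as follows. This sketch assumes the JSON file is a plain list aligned with the .npy rows. Inner product equals cosine similarity only for unit-norm vectors. This README does not say whether the stored embeddings are normalised, so the sketch normalises both bank and query explicitly, which is an assumption. faiss.IndexFlatIP, add and search are standard FAISS calls. MaskRetriever is a hypothetical wrapper that falls back to an exact NumPy search when FAISS is not installed.

```python
import json
from pathlib import Path

import numpy as np

try:
    import faiss  # from the faiss-cpu or faiss-gpu package
except ImportError:
    faiss = None  # exact NumPy inner-product search is used instead


def l2_normalize(x):
    """Row-wise L2 normalisation to float32 (assumed; the bank may not be pre-normalised)."""
    x = np.asarray(x, dtype=np.float32)
    if x.ndim == 1:
        x = x[None, :]
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return np.ascontiguousarray(x / np.maximum(norms, 1e-12), dtype=np.float32)


class MaskRetriever:
    """Hypothetical wrapper: inner-product search over report embeddings -> mask paths."""

    def __init__(self, bank, mask_paths):
        self.bank = l2_normalize(bank)
        self.mask_paths = list(mask_paths)
        self.index = None
        if faiss is not None:
            self.index = faiss.IndexFlatIP(self.bank.shape[1])  # exact inner-product index
            self.index.add(self.bank)

    @classmethod
    def from_files(cls, retrieval_dir="retrieval"):
        """Load the two bank files written by the index-building step."""
        root = Path(retrieval_dir)
        bank = np.load(root / "impression_embeddings.npy")
        with open(root / "impression_paths.json") as f:
            paths = json.load(f)  # assumed: plain list aligned with the .npy rows
        return cls(bank, paths)

    def search(self, query, k=1):
        """Return the top-k (mask_path, score) pairs for one query embedding."""
        q = l2_normalize(query)
        k = min(k, len(self.mask_paths))  # avoid FAISS padding results with -1
        if self.index is not None:
            scores, idx = self.index.search(q, k)
        else:
            all_scores = q @ self.bank.T
            idx = np.argsort(-all_scores, axis=1)[:, :k]
            scores = np.take_along_axis(all_scores, idx, axis=1)
        return [(self.mask_paths[i], float(s)) for i, s in zip(idx[0], scores[0])]
```

In use, MaskRetriever.from_files("./retrieval").search(report_embedding) would return the best-matching mask path, with report_embedding standing in for a hypothetical query vector.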
@article{Molino2026RAGText2CT,
title={Retrieval-Augmented Anatomical Guidance for Text-to-CT Generation},
author={Molino, Daniele and Caruso, Camillo Maria and Soda, Paolo and Guarrasi, Valerio},
year={2026},
journal={arXiv preprint arXiv:2603.08305}
}