High-performance FITS I/O for PyTorch with native pytorch-frame integration.
torchfits is a hybrid C++/Python library designed for maximum performance in astronomical data loading pipelines. It provides:
- Zero-copy tensor operations with direct cfitsio integration
- Batch coordinate transformations using wcslib with OpenMP parallelization
- Native pytorch-frame support for tabular data with lazy query building
- Multi-level caching (L1 memory + L2 disk) for remote files
- Streaming datasets for large-scale machine learning workflows
- C++ engine bypasses Python GIL for I/O operations
- SIMD-optimized data type conversions
- Tile-aware reading for compressed images
- Memory pools for efficient allocation
- Aggressive buffering with adaptive buffer sizes
See PERFORMANCE_OPTIMIZATIONS.md and COMPREHENSIVE_INVESTIGATION_REPORT.md for detailed technical performance analysis.
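To illustrate what tile-aware reading in the list above involves, here is a minimal sketch of the index arithmetic: given a cutout and a tile shape, it works out which tiles of a tiled image would need to be decompressed. The helper `tiles_for_cutout` and the 512x512 tile size are assumptions made for illustration; they are not part of the torchfits API.

```python
from itertools import product


def tiles_for_cutout(cutout, tile_shape):
    """Return the (row, col) indices of tiles that overlap a 2D cutout.

    cutout     -- ((row_start, row_stop), (col_start, col_stop)), stop exclusive
    tile_shape -- (tile_rows, tile_cols)
    """
    ranges = []
    for (start, stop), size in zip(cutout, tile_shape):
        if stop <= start:
            return []  # an empty cutout touches no tiles
        first = start // size      # tile containing the first pixel
        last = (stop - 1) // size  # tile containing the last pixel
        ranges.append(range(first, last + 1))
    return list(product(*ranges))


# A [1000:2000, 1000:2000] cutout on an image stored in 512x512 tiles
# (hypothetical tile size).
needed = tiles_for_cutout(((1000, 2000), (1000, 2000)), (512, 512))
print(len(needed), needed[0], needed[-1])
```

Only the tiles returned here would have to be decompressed, which is why a small cutout from a large compressed image can be much cheaper than a full read.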
- Direct tensor creation from FITS data
- Seamless pytorch-frame TensorFrame support
- GPU-ready data loading with device placement
- PyTorch DataLoader compatibility
- Native HTTP/HTTPS/FTP protocol support via cfitsio
- Intelligent caching for remote files
- Streaming datasets for cloud-scale data
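The idea behind a two-level cache for remote files can be sketched in plain Python: a small in-memory LRU layer sits in front of an on-disk layer, and the network is only touched when both miss. The class `TwoLevelCache`, the `fetch` callback, and the example URLs are hypothetical and only illustrate the pattern; the actual torchfits cache may be organized differently.

```python
import hashlib
import tempfile
from collections import OrderedDict
from pathlib import Path


class TwoLevelCache:
    """Hypothetical L1 (memory, LRU) + L2 (disk) cache keyed by URL."""

    def __init__(self, cache_dir, max_memory_items=8):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_memory_items = max_memory_items
        self._memory = OrderedDict()

    def _disk_path(self, url):
        # Hash the URL so that any string maps to a safe file name.
        return self.cache_dir / hashlib.sha256(url.encode()).hexdigest()

    def get(self, url, fetch):
        """Return the bytes for url, calling fetch(url) only on a full miss."""
        if url in self._memory:  # L1 hit
            self._memory.move_to_end(url)
            return self._memory[url]
        path = self._disk_path(url)
        if path.exists():  # L2 hit
            data = path.read_bytes()
        else:  # full miss: download and persist
            data = fetch(url)
            path.write_bytes(data)
        self._memory[url] = data
        if len(self._memory) > self.max_memory_items:
            self._memory.popitem(last=False)  # evict the least recently used
        return data


calls = []


def fake_fetch(url):
    # Stand-in for a real download; records each call.
    calls.append(url)
    return b"SIMPLE  =                    T"


cache = TwoLevelCache(tempfile.mkdtemp(), max_memory_items=1)
cache.get("https://example.org/a.fits", fake_fetch)  # miss, fetched
cache.get("https://example.org/a.fits", fake_fetch)  # L1 hit
cache.get("https://example.org/b.fits", fake_fetch)  # miss, evicts a from L1
cache.get("https://example.org/a.fits", fake_fetch)  # L2 hit, no fetch
print(len(calls))
```

The last lookup is served from disk even though the entry was evicted from memory, so only two downloads happen for four requests.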
# Clone the repository
git clone https://github.com/sfabbro/torchfits.git
cd torchfits
# Set up development environment
pixi install
pixi run dev

pip install torchfits

import torchfits
# Smart function - returns Tensor for images, TensorFrame for tables
data = torchfits.read("image.fits", device='cuda')
# Advanced multi-HDU operations
with torchfits.open("multi_hdu.fits") as hdul:
    image = hdul[0].to_tensor(device='cuda')
    table = hdul[1].materialize()  # Returns TensorFrame

# Lazy data access with slicing
import torch

with torchfits.open("large_image.fits") as hdul:
    hdu = hdul[0]
    # Get subset without loading full image
    cutout = hdu.data[1000:2000, 1000:2000]  # Returns Tensor
    # Memory-efficient statistics
    stats = hdu.stats()  # Computed in C++
    # WCS transformations
    pixels = torch.tensor([[100, 200], [300, 400]])
    world_coords = hdu.wcs.pixel_to_world(pixels)

# Lazy, chainable query building (torch-frame native)
with torchfits.open("catalog.fits") as hdul:
    table = hdul[1]
    # Build query plan (no I/O yet)
    query = (table
             .select(['RA', 'DEC', 'MAG_G'])
             .filter("MAG_G < 20 AND FLAG == 0")
             .head(10000))
    # Execute and get TensorFrame
    df = query.materialize()
    # Or stream in batches
    for batch in query.iter_rows(batch_size=1000):
        process_batch(batch)

from torchfits import create_dataloader
# Create optimized DataLoader
file_paths = ["image1.fits", "image2.fits", ...]
dataloader = create_dataloader(
    file_paths,
    batch_size=32,
    num_workers=4,
    device='cuda'
)
# Training loop
for batch in dataloader:
    # batch is already on GPU
    loss = model(batch)
    loss.backward()

- torchfits.read(path, hdu=0, device='cpu') - Smart read function
- torchfits.write(path, data, header=None) - Smart write function
- torchfits.open(path, mode='r') - Multi-HDU file access
- HDUList - Container for multiple HDUs with context management
- TensorHDU - Image/cube data with lazy loading and WCS support
- TableHDU - Tabular data with torch-frame integration
- WCS - Batch coordinate transformations
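The lazy, chainable query style described for table data can be sketched without any FITS I/O. In this minimal illustration each method only records a step and returns a new plan; rows are touched only when `materialize()` or `iter_rows()` runs. The `LazyQuery` class is hypothetical, operates on a list of dicts rather than a FITS table, and takes a Python callable as its filter instead of the string expression torchfits accepts.

```python
from itertools import islice


class LazyQuery:
    """Hypothetical query plan: records steps, runs them on materialize()."""

    def __init__(self, rows, columns=None, predicates=(), limit=None):
        self._rows = rows
        self._columns = columns
        self._predicates = tuple(predicates)
        self._limit = limit

    def select(self, columns):
        return LazyQuery(self._rows, list(columns), self._predicates, self._limit)

    def filter(self, predicate):
        return LazyQuery(self._rows, self._columns,
                         self._predicates + (predicate,), self._limit)

    def head(self, n):
        return LazyQuery(self._rows, self._columns, self._predicates, n)

    def _execute(self):
        # Rows are only read here, one at a time. Filtering runs before
        # column projection, so predicates may use unselected columns.
        matching = (r for r in self._rows
                    if all(p(r) for p in self._predicates))
        for row in islice(matching, self._limit):
            if self._columns is None:
                yield dict(row)
            else:
                yield {c: row[c] for c in self._columns}

    def materialize(self):
        return list(self._execute())

    def iter_rows(self, batch_size):
        it = self._execute()
        while batch := list(islice(it, batch_size)):
            yield batch


catalog = [
    {"RA": 10.0, "DEC": -5.0, "MAG_G": 18.2, "FLAG": 0},
    {"RA": 11.0, "DEC": -4.0, "MAG_G": 21.5, "FLAG": 0},
    {"RA": 12.0, "DEC": -3.0, "MAG_G": 19.1, "FLAG": 1},
    {"RA": 13.0, "DEC": -2.0, "MAG_G": 17.9, "FLAG": 0},
]
query = (LazyQuery(catalog)
         .select(["RA", "DEC", "MAG_G"])
         .filter(lambda r: r["MAG_G"] < 20 and r["FLAG"] == 0)
         .head(10000))
result = query.materialize()
print(result)
```

Because every step returns a new plan object, a partially built query can be reused as the starting point for several variants without re-reading any data.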
- FITSDataset - Map-style dataset for random access
- IterableFITSDataset - Streaming dataset for large-scale data
- create_dataloader() - Factory for optimized DataLoaders
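The difference between a map-style and a streaming dataset comes down to which protocol the class implements. The sketch below uses plain Python classes to keep it dependency-free; in PyTorch the same two shapes correspond to `torch.utils.data.Dataset` (`__getitem__` and `__len__`) and `torch.utils.data.IterableDataset` (`__iter__`). The class names and `fake_loader` are hypothetical stand-ins, not the torchfits implementations.

```python
class MapStyleFiles:
    """Random access: supports len() and integer indexing."""

    def __init__(self, paths, loader):
        self.paths = list(paths)
        self.loader = loader

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return self.loader(self.paths[index])


class StreamingFiles:
    """Sequential access: only supports iteration, one file at a time."""

    def __init__(self, paths, loader):
        self.paths = paths  # may be any iterable, even a generator
        self.loader = loader

    def __iter__(self):
        for path in self.paths:
            yield self.loader(path)


def fake_loader(path):
    # Stand-in for reading a FITS file into a tensor.
    return f"pixels of {path}"


paths = ["image1.fits", "image2.fits", "image3.fits"]
indexed = MapStyleFiles(paths, fake_loader)
streamed = StreamingFiles(paths, fake_loader)

print(len(indexed), indexed[2])
print(next(iter(streamed)))
```

A map-style dataset lets a sampler shuffle by index, while a streaming dataset never needs to know how many items exist, which suits file lists too large to enumerate up front.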
torchfits is designed for maximum performance:
- 10-100x faster than astropy for large arrays
- Zero-copy tensor creation from FITS data
- Parallel I/O with OpenMP acceleration
- Optimized memory usage with shared buffers
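What "zero-copy" means in practice can be shown with the standard library alone: a view shares memory with the buffer it was created from, while a copy does not. This is a generic illustration, not torchfits code; in PyTorch, `torch.from_numpy` and `torch.frombuffer` give the same sharing behaviour for tensors.

```python
import array

# A buffer standing in for pixel data that has already been read into memory.
pixels = array.array("f", [1.0, 2.0, 3.0, 4.0])

view = memoryview(pixels)        # zero-copy: shares the array's memory
copy = array.array("f", pixels)  # independent copy of the same values

pixels[0] = 99.0                 # mutate the original buffer

# The view sees the change, the copy does not.
print(view[0], copy[0])
```

The view costs no extra memory regardless of buffer size, which is where the savings come from for large images.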
# Set up development environment
pixi install
# Build C++ extension
pixi run build
# Run tests
pixi run test
# Run benchmarks
pixi run bench

src/torchfits/
├── __init__.py          # Top-level API
├── hdu.py               # Core HDU classes
├── wcs.py               # WCS functionality
├── datasets.py          # PyTorch datasets
├── dataloader.py        # DataLoader factories
└── cpp_src/             # C++ extension
    ├── fits.cpp         # FITS I/O engine
    ├── wcs.cpp          # WCS transformations
    └── bindings.cpp     # nanobind interface

GPL-2 License. See LICENSE file for details.
Contributions welcome! Please see CONTRIBUTING.md for guidelines.