diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 65e8c50..53c97c4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -33,6 +33,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install pandas>=2.2.0 numpy>=1.26.0 --only-binary :all:
+          pip install --no-cache-dir --upgrade "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git"
           pip install -e .[ci]
 
       - name: Run unit tests
diff --git a/pyproject.toml b/pyproject.toml
index 13ecc52..6446a6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,8 @@ requires-python = ">=3.10"
 
 # Core dependencies (CPU-only, minimal)
 dependencies = [
+    # rfi_toolbox: Shared RFI utilities (io, preprocessing, evaluation, datasets)
+    "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git",
     "numpy>=1.26.0",
     "scipy>=1.10.0",
     "pandas>=2.2.0",
@@ -33,7 +35,6 @@ dependencies = [
     "tqdm>=4.65.0",
     "matplotlib>=3.7.0",
     "datasets>=2.10.0",
-    "patchify>=0.2.3",
     "scikit-image>=0.20.0",
 ]
 
diff --git a/src/samrfi/data/__init__.py b/src/samrfi/data/__init__.py
index a56245b..36b8635 100644
--- a/src/samrfi/data/__init__.py
+++ b/src/samrfi/data/__init__.py
@@ -1,23 +1,35 @@
 """
 Data module - MS loading, preprocessing, and dataset creation
+
+NOTE: Core data utilities (MSLoader, Preprocessor, TorchDataset, BatchWriter)
+have been moved to rfi_toolbox for sharing across ML methods.
+This module provides forward-compatibility imports.
 """
 
+from rfi_toolbox.datasets import BatchWriter, TorchDataset
+
+# Forward imports from rfi_toolbox (shared utilities)
+from rfi_toolbox.io import MSLoader
+from rfi_toolbox.preprocessing import GPUPreprocessor, Preprocessor
+
+# SAM2-specific modules (stay in samrfi)
 from .adaptive_patcher import AdaptivePatcher, check_ms_compatibility
 from .gpu_dataset import GPUBatchTransformDataset, GPUTransformDataset
 from .gpu_transforms import GPUTransforms, create_gpu_transforms
 from .hf_dataset_wrapper import HFDatasetWrapper
-from .preprocessor import GPUPreprocessor, Preprocessor
 from .ram_dataset import RAMCachedDataset
 from .sam_dataset import BatchedDataset, SAMDataset
-from .torch_dataset import BatchWriter, TorchDataset
 
 __all__ = [
+    # Shared utilities (from rfi_toolbox)
+    "MSLoader",
     "Preprocessor",
     "GPUPreprocessor",
-    "SAMDataset",
-    "BatchedDataset",
     "TorchDataset",
     "BatchWriter",
+    # SAM2-specific
+    "SAMDataset",
+    "BatchedDataset",
     "HFDatasetWrapper",
     "AdaptivePatcher",
     "check_ms_compatibility",
@@ -27,6 +39,3 @@
     "GPUBatchTransformDataset",
     "RAMCachedDataset",
 ]
-
-# Note: MSLoader requires CASA and is not imported by default
-# Use: from samrfi.data.ms_loader import MSLoader
diff --git a/src/samrfi/data/preprocessor.py b/src/samrfi/data/preprocessor.py
index 5fe43ea..f06fa37 100644
--- a/src/samrfi/data/preprocessor.py
+++ b/src/samrfi/data/preprocessor.py
@@ -9,7 +9,6 @@
 
 import numpy as np
 import torch
-from patchify import patchify
 from scipy import stats
 
 from samrfi.utils import logger
@@ -17,6 +16,29 @@
 from .torch_dataset import TorchDataset
 
 
+def patchify(array, patch_shape, step):
+    """
+    Extract patches from 2D array using torch.unfold (replaces patchify library).
+
+    Args:
+        array: 2D numpy array (H, W)
+        patch_shape: Tuple (patch_h, patch_w)
+        step: Step size for patch extraction
+
+    Returns:
+        4D array (n_patches_h, n_patches_w, patch_h, patch_w)
+    """
+    patch_h, patch_w = patch_shape
+    tensor = torch.from_numpy(array)
+
+    # Use unfold to extract patches: (H, W) -> (n_h, n_w, patch_h, patch_w)
+    patches = tensor.unfold(0, patch_h, step).unfold(1, patch_w, step)
+
+    # Rearrange to match patchify output format
+    patches = patches.contiguous().numpy()
+    return patches
+
+
 # Standalone functions for multiprocessing (must be picklable)
 def _patchify_single_waterfall(waterfall, patch_size):
     """
diff --git a/src/samrfi/data_generation/__init__.py b/src/samrfi/data_generation/__init__.py
index 1dc7ed8..85102f6 100644
--- a/src/samrfi/data_generation/__init__.py
+++ b/src/samrfi/data_generation/__init__.py
@@ -1,8 +1,15 @@
-"""Data generation modules for SAM-RFI"""
-
-from .synthetic_generator import SyntheticDataGenerator
+"""Data generation modules for SAM-RFI
+
+NOTE: SyntheticDataGenerator has been moved to rfi_toolbox for sharing across ML methods.
+For backward compatibility, we provide both:
+- rfi_toolbox.data_generation.SyntheticDataGenerator (recommended)
+- samrfi.data_generation.synthetic_generator.SyntheticDataGenerator (deprecated, will be removed)
+"""
 
-__all__ = ["SyntheticDataGenerator"]
+# Forward import from rfi_toolbox (recommended)
+from rfi_toolbox.data_generation import SyntheticDataGenerator
 
-# Note: MSDataGenerator requires CASA and is not imported by default
-# Use: from samrfi.data_generation.ms_generator import MSDataGenerator
+# SAM2-specific data generation
+from .ms_generator import MSDataGenerator
+
+__all__ = ["SyntheticDataGenerator", "MSDataGenerator"]
diff --git a/src/samrfi/evaluation/__init__.py b/src/samrfi/evaluation/__init__.py
index f603061..e5c8de3 100644
--- a/src/samrfi/evaluation/__init__.py
+++ b/src/samrfi/evaluation/__init__.py
@@ -1,21 +1,27 @@
 """
 Evaluation metrics and validation tools for RFI segmentation
+
+NOTE: Core metrics (IoU, F1, Dice, FFI, statistics) have been moved to rfi_toolbox
+for sharing across ML methods. This module provides forward-compatibility imports.
 """
 
-from .metrics import (
+# Forward imports from rfi_toolbox (shared metrics)
+from rfi_toolbox.evaluation import (
+    compute_calcquality,
     compute_dice,
     compute_f1,
+    compute_ffi,
     compute_iou,
     compute_precision,
     compute_recall,
-    evaluate_segmentation,
-)
-from .statistics import (
-    compute_calcquality,
-    compute_ffi,
     compute_statistics,
+    evaluate_segmentation,
     print_statistics_comparison,
 )
+from rfi_toolbox.io import inject_synthetic_data
+
+# SAM2-specific evaluation (if any remain in local files)
+# Currently all metrics are in rfi_toolbox
 
 __all__ = [
     "compute_iou",
@@ -28,12 +34,5 @@
     "compute_ffi",
     "compute_calcquality",
     "print_statistics_comparison",
+    "inject_synthetic_data",
 ]
-
-# Optional CASA dependency for MS injection
-try:
-    from .ms_injection import inject_synthetic_data
-
-    __all__.append("inject_synthetic_data")
-except ImportError:
-    pass  # CASA not available
diff --git a/src/samrfi/inference/predictor.py b/src/samrfi/inference/predictor.py
index 884ba91..b45b326 100644
--- a/src/samrfi/inference/predictor.py
+++ b/src/samrfi/inference/predictor.py
@@ -12,7 +12,7 @@
 from tqdm import tqdm
 from transformers import Sam2Model, Sam2Processor
 
-from samrfi.data import AdaptivePatcher, Preprocessor, SAMDataset
+from samrfi.data import AdaptivePatcher, MSLoader, Preprocessor, SAMDataset
 from samrfi.utils import logger
 from samrfi.utils.errors import CheckpointMismatchError
 
@@ -554,8 +554,6 @@ def predict_ms(
         Returns:
             Predicted flags array (baselines, pols, channels, times)
         """
-        from samrfi.data.ms_loader import MSLoader
-
         logger.info(f"\n{'='*60}")
         logger.info("RFI Prediction - Single Pass")
         logger.info(f"{'='*60}")
@@ -677,8 +675,6 @@ def predict_iterative(
         Returns:
             Cumulative flags from all iterations
         """
-        from samrfi.data.ms_loader import MSLoader
-
         logger.info(f"\n{'='*60}")
         logger.info(f"RFI Prediction - Iterative ({num_iterations} passes)")
         logger.info(f"{'='*60}")