1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -33,6 +33,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pandas>=2.2.0 numpy>=1.26.0 --only-binary :all:
pip install --no-cache-dir --upgrade "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git"
pip install -e .[ci]

- name: Run unit tests
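A lightweight smoke check after these install steps would surface a broken rfi_toolbox pin before the unit tests run. This is a sketch of a hypothetical extra CI step, not part of this PR; numpy and pandas expose __version__, while rfi_toolbox is only imported because its public attributes are not assumed here.

# Sketch of an optional post-install smoke check (hypothetical extra CI step):
import numpy
import pandas
import rfi_toolbox  # noqa: F401  # installed from the git URL above

print("numpy", numpy.__version__)
print("pandas", pandas.__version__)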
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -25,6 +25,8 @@ requires-python = ">=3.10"

# Core dependencies (CPU-only, minimal)
dependencies = [
# rfi_toolbox: Shared RFI utilities (io, preprocessing, evaluation, datasets)
"rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git",
"numpy>=1.26.0",
"scipy>=1.10.0",
"pandas>=2.2.0",
@@ -33,7 +35,6 @@ dependencies = [
"tqdm>=4.65.0",
"matplotlib>=3.7.0",
"datasets>=2.10.0",
"patchify>=0.2.3",
"scikit-image>=0.20.0",
]

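With patchify dropped and rfi_toolbox added as a direct git dependency, a fresh environment can be sanity-checked against the new dependency set. A minimal sketch, assuming only that the install succeeded:

# Sketch: confirm the dependency swap in a freshly built environment.
import importlib.util

# rfi_toolbox is now declared in pyproject.toml (installed from the git URL).
assert importlib.util.find_spec("rfi_toolbox") is not None

# patchify is no longer declared; samrfi now ships a torch-based replacement
# (see src/samrfi/data/preprocessor.py below), so the package may be absent.
print("patchify present:", importlib.util.find_spec("patchify") is not None)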
23 changes: 16 additions & 7 deletions src/samrfi/data/__init__.py
@@ -1,23 +1,35 @@
"""
Data module - MS loading, preprocessing, and dataset creation

NOTE: Core data utilities (MSLoader, Preprocessor, TorchDataset, BatchWriter)
have been moved to rfi_toolbox for sharing across ML methods.
This module provides forward-compatibility imports.
"""

from rfi_toolbox.datasets import BatchWriter, TorchDataset

# Forward imports from rfi_toolbox (shared utilities)
from rfi_toolbox.io import MSLoader
from rfi_toolbox.preprocessing import GPUPreprocessor, Preprocessor

# SAM2-specific modules (stay in samrfi)
from .adaptive_patcher import AdaptivePatcher, check_ms_compatibility
from .gpu_dataset import GPUBatchTransformDataset, GPUTransformDataset
from .gpu_transforms import GPUTransforms, create_gpu_transforms
from .hf_dataset_wrapper import HFDatasetWrapper
from .preprocessor import GPUPreprocessor, Preprocessor
from .ram_dataset import RAMCachedDataset
from .sam_dataset import BatchedDataset, SAMDataset
from .torch_dataset import BatchWriter, TorchDataset

__all__ = [
# Shared utilities (from rfi_toolbox)
"MSLoader",
"Preprocessor",
"GPUPreprocessor",
"SAMDataset",
"BatchedDataset",
"TorchDataset",
"BatchWriter",
# SAM2-specific
"SAMDataset",
"BatchedDataset",
"HFDatasetWrapper",
"AdaptivePatcher",
"check_ms_compatibility",
@@ -27,6 +39,3 @@
"GPUBatchTransformDataset",
"RAMCachedDataset",
]

# Note: MSLoader requires CASA and is not imported by default
# Use: from samrfi.data.ms_loader import MSLoader
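Because these are forward-compatibility re-exports, downstream code importing from samrfi.data keeps receiving the shared rfi_toolbox objects rather than copies. A minimal sketch of that guarantee, assuming samrfi.data imports cleanly in the target environment now that MSLoader comes in at module load time:

# Sketch: samrfi.data re-exports the rfi_toolbox classes themselves.
from rfi_toolbox.preprocessing import Preprocessor as ToolboxPreprocessor

from samrfi.data import Preprocessor

# A re-export, not a copy: both names bind the same class object.
assert Preprocessor is ToolboxPreprocessor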
24 changes: 23 additions & 1 deletion src/samrfi/data/preprocessor.py
@@ -9,14 +9,36 @@

import numpy as np
import torch
from patchify import patchify
from scipy import stats

from samrfi.utils import logger

from .torch_dataset import TorchDataset


def patchify(array, patch_shape, step):
"""
Extract patches from a 2D array using torch.unfold (replaces the patchify library).

Args:
array: 2D numpy array (H, W)
patch_shape: Tuple (patch_h, patch_w)
step: Step size for patch extraction (applied along both dimensions)

Returns:
4D array (n_patches_h, n_patches_w, patch_h, patch_w)
"""
patch_h, patch_w = patch_shape
tensor = torch.from_numpy(array)

# Use unfold to extract patches: (H, W) -> (n_h, n_w, patch_h, patch_w)
patches = tensor.unfold(0, patch_h, step).unfold(1, patch_w, step)

# Rearrange to match patchify output format
patches = patches.contiguous().numpy()
return patches


# Standalone functions for multiprocessing (must be picklable)
def _patchify_single_waterfall(waterfall, patch_size):
"""
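The drop-in patchify above depends on Tensor.unfold producing the same (n_patches_h, n_patches_w, patch_h, patch_w) layout the removed library returned. A small equivalence sketch against NumPy's sliding_window_view, with an arbitrary array size and a step equal to the patch size:

# Sketch: unfold-based patches equal strided windows subsampled at the same step.
import numpy as np
import torch

arr = np.random.rand(64, 128).astype(np.float32)
patch_h, patch_w, step = 16, 16, 16

patches = torch.from_numpy(arr).unfold(0, patch_h, step).unfold(1, patch_w, step)
patches = patches.contiguous().numpy()  # (n_h, n_w, patch_h, patch_w)

windows = np.lib.stride_tricks.sliding_window_view(arr, (patch_h, patch_w))
expected = windows[::step, ::step]      # keep only windows starting on the step grid

assert patches.shape == expected.shape == (4, 8, 16, 16)
np.testing.assert_array_equal(patches, expected)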
17 changes: 12 additions & 5 deletions src/samrfi/data_generation/__init__.py
@@ -1,8 +1,15 @@
"""Data generation modules for SAM-RFI"""
"""Data generation modules for SAM-RFI

from .synthetic_generator import SyntheticDataGenerator
NOTE: SyntheticDataGenerator has been moved to rfi_toolbox for sharing across ML methods.
For backward compatibility, we provide both:
- rfi_toolbox.data_generation.SyntheticDataGenerator (recommended)
- samrfi.data_generation.synthetic_generator.SyntheticDataGenerator (deprecated, will be removed)
"""

__all__ = ["SyntheticDataGenerator"]
# Forward import from rfi_toolbox (recommended)
from rfi_toolbox.data_generation import SyntheticDataGenerator

# Note: MSDataGenerator requires CASA and is not imported by default
# Use: from samrfi.data_generation.ms_generator import MSDataGenerator
# SAM2-specific data generation
from .ms_generator import MSDataGenerator

__all__ = ["SyntheticDataGenerator", "MSDataGenerator"]
27 changes: 13 additions & 14 deletions src/samrfi/evaluation/__init__.py
@@ -1,21 +1,27 @@
"""
Evaluation metrics and validation tools for RFI segmentation

NOTE: Core metrics (IoU, F1, Dice, FFI, statistics) have been moved to rfi_toolbox
for sharing across ML methods. This module provides forward-compatibility imports.
"""

from .metrics import (
# Forward imports from rfi_toolbox (shared metrics)
from rfi_toolbox.evaluation import (
compute_calcquality,
compute_dice,
compute_f1,
compute_ffi,
compute_iou,
compute_precision,
compute_recall,
evaluate_segmentation,
)
from .statistics import (
compute_calcquality,
compute_ffi,
compute_statistics,
evaluate_segmentation,
print_statistics_comparison,
)
from rfi_toolbox.io import inject_synthetic_data

# SAM2-specific evaluation (if any remain in local files)
# Currently all metrics are in rfi_toolbox

__all__ = [
"compute_iou",
@@ -28,12 +34,5 @@
"compute_ffi",
"compute_calcquality",
"print_statistics_comparison",
"inject_synthetic_data",
]

# Optional CASA dependency for MS injection
try:
from .ms_injection import inject_synthetic_data

__all__.append("inject_synthetic_data")
except ImportError:
pass # CASA not available
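As with the data module, the metric names exported here are re-exports of the shared implementations. A sketch of how a downstream test could pin that down without assuming anything about the metric signatures themselves:

# Sketch: samrfi.evaluation forwards the shared rfi_toolbox metrics.
from rfi_toolbox.evaluation import compute_iou as toolbox_compute_iou

from samrfi.evaluation import compute_iou

assert compute_iou is toolbox_compute_iou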
6 changes: 1 addition & 5 deletions src/samrfi/inference/predictor.py
@@ -12,7 +12,7 @@
from tqdm import tqdm
from transformers import Sam2Model, Sam2Processor

from samrfi.data import AdaptivePatcher, Preprocessor, SAMDataset
from samrfi.data import AdaptivePatcher, MSLoader, Preprocessor, SAMDataset
from samrfi.utils import logger
from samrfi.utils.errors import CheckpointMismatchError

@@ -554,8 +554,6 @@ def predict_ms(
Returns:
Predicted flags array (baselines, pols, channels, times)
"""
from samrfi.data.ms_loader import MSLoader

logger.info(f"\n{'='*60}")
logger.info("RFI Prediction - Single Pass")
logger.info(f"{'='*60}")
@@ -677,8 +675,6 @@ def predict_iterative(
Returns:
Cumulative flags from all iterations
"""
from samrfi.data.ms_loader import MSLoader

logger.info(f"\n{'='*60}")
logger.info(f"RFI Prediction - Iterative ({num_iterations} passes)")
logger.info(f"{'='*60}")
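Since MSLoader is now re-exported by samrfi.data, the single module-level import above replaces the lazy in-function imports that predict_ms and predict_iterative previously carried; whether the CASA dependency behind MSLoader is still deferred inside rfi_toolbox.io is an assumption, not something this diff shows. The resulting import pattern, as a sketch:

# Sketch: one module-level import now serves both prediction entry points.
from samrfi.data import AdaptivePatcher, MSLoader, Preprocessor, SAMDataset  # noqa: F401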