This repository contains a fine-tuned RT-1 (Robotic Transformer) model that predicts low-level execution parameters for food acquisition tasks. Given an image of a food item and a natural language instruction, the model outputs execution parameters [x_1, y_1, x_2, y_2] that specify how to perform the manipulation action. The model supports three manipulation actions: skewering, scooping, and twirling.
There are two main scripts to run:
- Training: `slurm_scripts/run_training.slurm` - starts the fine-tuning process
- Inference server: `slurm_scripts/api_server.slurm` - runs the API server for predictions

Submit the training job with `sbatch slurm_scripts/run_training.slurm`; this script calls `train.py`.

Submit the server job with `sbatch slurm_scripts/api_server.slurm`; this starts a Flask API server that loads a trained checkpoint and serves predictions via HTTP endpoints.
The project uses a conda environment, with package versions specified in `package_verisons.txt`. To recreate the environment, create a fresh conda environment and install the pinned versions from that file.
Key dependencies:
- TensorFlow 2.11.0
- TF-Agents 0.15.0
- Gin-Config 0.5.0
- Flask 3.1.2
- NumPy 1.21.6
- OpenCV 4.7.0.72
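To check that an active environment matches these pins, a quick version probe can help (a minimal sketch, assuming Python 3.8+ and the standard pip distribution names):

```python
# Print installed versions of the key dependencies so they can be
# compared against package_verisons.txt.
from importlib.metadata import version

for dist in ["tensorflow", "tf-agents", "gin-config", "Flask", "numpy", "opencv-python"]:
    print(f"{dist}: {version(dist)}")
```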
The model architecture follows this chain:
- `agent_wrapper.py` - Top-level wrapper that creates the Agent model
  - Wraps `ActorNetwork` to match the checkpoint variable naming hierarchy
  - Ensures variable names follow the pattern `agent/_actor_network/...`
- `transformer_network.py` - Main transformer-based actor network
  - Implements the `TransformerNetwork` class that combines image and action tokens
  - Uses `RT1ImageTokenizer` and `RT1ActionTokenizer` for tokenization
  - Processes tokens through transformer layers with causal attention masks
- `sequence_agent.py` - TF-Agents wrapper
  - `SequenceAgent`: Wraps the transformer network for training/inference
  - `SequencePolicy`: Policy that outputs actions via the actor network
  - Handles loss computation and gradient updates
- `transformer.py` - Transformer implementation
  - `Transformer`: Multi-layer transformer decoder
  - `_TransformerLayer`: Single transformer block with multi-head attention and feed-forward layers
- `tokenizers/image_tokenizer.py` - Image tokenization
  - `RT1ImageTokenizer`: Converts images to tokens using EfficientNet + TokenLearner
  - Uses FiLM (Feature-wise Linear Modulation) for natural language conditioning
- `tokenizers/action_tokenizer.py` - Action tokenization
  - `RT1ActionTokenizer`: Converts execution parameters (4D coordinates) to discrete tokens
  - Supports both tokenization (action → tokens) and detokenization (tokens → action); see the sketch after this list
- `tokenizers/token_learner.py` - Token learning module
  - Learns to extract compact token representations from image features
- `film_efficientnet/` - EfficientNet encoder with FiLM conditioning
  - `pretrained_efficientnet_encoder.py`: Loads pretrained EfficientNet weights
  - `film_conditioning_layer.py`: FiLM layer for language conditioning
  - `film_efficientnet_encoder.py`: Combines EfficientNet with FiLM
- `train.py` - Main training script
  - Loads the dataset from an `.npz` file
  - Creates train/validation splits
  - Initializes the agent with a pretrained checkpoint
  - Runs the training loop with validation
  - Saves the best checkpoints based on validation metrics (centroid MAE for execution parameters)
- `utils.py` - Utility functions for training
  - `create_agent()`: Creates and configures the agent
  - `create_npz_dataset()`: Loads and preprocesses the dataset
  - `train_step()`: Single training step
  - `validate_step_with_metrics()`: Validation with execution parameter metrics
  - `load_pretrained_checkpoint()` / `load_resume_checkpoint()`: Checkpoint loading
  - `apply_freezing_logic()`: Freezes/unfreezes model components
  - Data augmentation functions (rotation, lighting)
- `test_checkpoint_loading.py` - Checkpoint loading utilities
  - `custom_load_checkpoint()`: Custom checkpoint loading with variable name mapping
  - Handles mapping between checkpoint variable names and model variable names
  - Supports loading pretrained RT-1 checkpoints
- `api_batch_server.py` - Flask API server for inference
  - `/predict`: Single prediction endpoint
  - `/predict_batch`: Batch prediction with Monte Carlo Dropout for uncertainty
  - `/available_combinations`: Lists available action-food combinations
  - `/load_model`: Loads/reloads a model checkpoint
  - Uses pre-computed embeddings from the `embeddings/` directory
- `api_server.py` - Alternative API server (simpler version)
  - Similar to `api_batch_server.py` but without batch processing
- `api_client.py` - Client for sending API requests
  - Example client code for interacting with the API server
- `batch_api_inference.py` - Batch inference script
  - Processes multiple images from a directory
  - Saves results to a JSON file
  - Can be used for evaluation on test sets
- `dropout_settings.py` - Dropout configuration
  - Controls dropout rates for different components
  - Enables/disables dropout during inference
- `policy_specs.pbtxt` - Policy specifications (protobuf format)
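To make the action tokenization concrete, here is a minimal NumPy sketch of the discretization that `RT1ActionTokenizer` performs, assuming uniform binning of each [0, 1] coordinate with `vocab_size=256` (the real class operates on TF tensors and action specs):

```python
import numpy as np

VOCAB_SIZE = 256  # matches vocab_size=256 from the model details below

def tokenize(params):
    """Map 4D execution parameters in [0, 1] to discrete tokens."""
    params = np.clip(np.asarray(params), 0.0, 1.0)
    return np.minimum((params * VOCAB_SIZE).astype(np.int32), VOCAB_SIZE - 1)

def detokenize(tokens):
    """Map discrete tokens back to bin-center coordinates in [0, 1]."""
    return (np.asarray(tokens).astype(np.float32) + 0.5) / VOCAB_SIZE

tokens = tokenize([0.12, 0.40, 0.87, 0.55])   # e.g. [x1, y1, x2, y2]
print(tokens)              # -> [ 30 102 222 140]
print(detokenize(tokens))  # -> approximately the original parameters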
```
train.py
  ↓
utils.create_npz_dataset() → Loads .npz file, creates train/val splits
  ↓
utils.create_agent() → Creates Agent model
  ↓
agent_wrapper.Agent → Wraps ActorNetwork
  ↓
transformer_network.TransformerNetwork → Main network
  ↓
  ├─→ image_tokenizer.RT1ImageTokenizer → EfficientNet + TokenLearner
  ├─→ action_tokenizer.RT1ActionTokenizer → Execution parameters → tokens
  └─→ transformer.Transformer → Processes tokens
  ↓
Training loop: train_step() → loss → gradients → optimizer
  ↓
Checkpoint saving (best model based on validation metrics)
```
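The training loop itself reduces to a standard TF 2.x gradient step. The sketch below is illustrative only; `agent`, the batch structure, and the optimizer settings are stand-ins rather than the exact signatures in `utils.py`:

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)  # assumed settings

@tf.function
def train_step(agent, images, embeddings, target_params):
    """One gradient update: forward pass, loss, backprop, optimizer step."""
    with tf.GradientTape() as tape:
        # Stand-in forward call; the real agent consumes TF-Agents time steps.
        predicted_params = agent((images, embeddings), training=True)
        # Huber loss is the Keras equivalent of the default smooth_l1 loss.
        loss = tf.keras.losses.Huber()(target_params, predicted_params)
    grads = tape.gradient(loss, agent.trainable_variables)
    optimizer.apply_gradients(zip(grads, agent.trainable_variables))
    return loss
```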
```
api_batch_server.py
  ↓
Load checkpoint → test_checkpoint_loading.custom_load_checkpoint()
  ↓
Create agent → agent_wrapper.Agent
  ↓
Receive image + action/food_type
  ↓
Load embedding from embeddings/ directory
  ↓
agent.forward() → transformer_network.TransformerNetwork
  ↓
  ├─→ image_tokenizer → image tokens
  ├─→ action_tokenizer → detokenize → execution parameters (4D coordinates)
  └─→ transformer → process sequence
  ↓
Return execution parameters (start_x, start_y, end_x, end_y)
```
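Server-side, the `/predict` endpoint reduces to the following shape. This is a simplified sketch: `load_embedding()` and `run_model()` are hypothetical stand-ins for the checkpoint-backed agent in `api_batch_server.py`:

```python
import base64

import cv2
import numpy as np
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    payload = request.get_json()
    # Decode the base64-encoded image sent by the client.
    buf = np.frombuffer(base64.b64decode(payload["image"]), dtype=np.uint8)
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (236, 236)).astype(np.float32) / 255.0
    # Look up the pre-computed 512-d embedding for this action/food pair;
    # load_embedding is a hypothetical helper used for illustration.
    embedding = load_embedding(payload["action"], payload["food_type"])
    # run_model is a hypothetical stand-in for the agent forward pass.
    start_x, start_y, end_x, end_y = run_model(image, embedding)
    return jsonify({"prediction": {
        "start_x": float(start_x), "start_y": float(start_y),
        "end_x": float(end_x), "end_y": float(end_y),
    }})
```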
- Dataset file: `/share/bhattacharjee/frank_data/dataset_sk_sc.npz`
- Data directory: `/share/bhattacharjee/frank_data`
- Train-val split: `/share/bhattacharjee/frank_data/train_val_split_sk_sc.npz`
The dataset contains:
- Images: Food item images (236x236 pixels)
- Execution parameters: 4D coordinates (x1, y1, x2, y2) normalized to [0, 1], representing start and end points for manipulation actions
- Natural language embeddings: 512-dimensional embeddings for action-food combinations
Actions: skewering, scooping, twirling
Food Items: banana, blueberry, broccoli, brownie, cantaloupe, cherry_tomato, chicken, fettucine, grape, green_bean, honeydew, lettuce, mac_and_cheese, mashed_potato, meatball, oatmeal, pineapple, rice, sausage, spaghetti, strawberry, watermelon
Pre-computed embeddings for all action-food combinations are stored in the embeddings/ directory.
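A hedged sketch of inspecting these files follows; the `.npz` key names and the `embeddings/` filename pattern are assumptions for illustration, so check the actual files for the exact names:

```python
import numpy as np

# Inspect the dataset archive (key names vary; print them to find out).
data = np.load("/share/bhattacharjee/frank_data/dataset_sk_sc.npz")
print(list(data.keys()))

# Load one pre-computed embedding; the filename pattern is hypothetical.
emb = np.load("embeddings/skewering_banana.npy")
print(emb.shape)  # expected: (512,)
```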
The model is based on RT-1 architecture:
- Image Encoder: EfficientNet-B3 with FiLM conditioning
- TokenLearner: Learns compact token representations (8 tokens per image)
- Image Tokenizer: Combines EfficientNet features with TokenLearner
- Action Tokenizer: Discretizes execution parameters (4D coordinates) into tokens (vocab_size=256)
- Transformer: 8-layer transformer with causal attention masks
- Output: Action token logits → detokenized to execution parameters (4D coordinates: start_x, start_y, end_x, end_y)
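FiLM conditioning is easiest to see in isolation: the language embedding is projected to per-channel scale (gamma) and shift (beta) terms that modulate the image feature maps. A minimal Keras sketch of the idea, not the exact `film_conditioning_layer.py` implementation:

```python
import tensorflow as tf

class FiLM(tf.keras.layers.Layer):
    """Feature-wise Linear Modulation: features * (1 + gamma) + beta."""

    def __init__(self, num_channels):
        super().__init__()
        self.gamma = tf.keras.layers.Dense(num_channels)
        self.beta = tf.keras.layers.Dense(num_channels)

    def call(self, features, embedding):
        # features: [B, H, W, C] image feature maps
        # embedding: [B, 512] language embedding
        gamma = self.gamma(embedding)[:, None, None, :]  # broadcast over H, W
        beta = self.beta(embedding)[:, None, None, :]
        return features * (1.0 + gamma) + beta

film = FiLM(num_channels=48)
out = film(tf.zeros([1, 28, 28, 48]), tf.zeros([1, 512]))
print(out.shape)  # (1, 28, 28, 48)
```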
Training can be configured via command-line flags in train.py:
- `--config_file`: Gin config file (default: `configs/transformer_mixin.gin`)
- `--checkpoint_dir`: Output directory for checkpoints
- `--pretrained_checkpoint`: Path to a pretrained RT-1 checkpoint
- `--resume_checkpoint`: Path to a checkpoint to resume from
- `--batch_size`: Training batch size (default: 16)
- `--num_epochs`: Number of epochs (default: 30)
- `--loss_type`: Loss function - `cross_entropy`, `l1`, or `smooth_l1` (default: `smooth_l1`)
- `--freeze_efficientnet`, `--freeze_tokenlearner`, `--freeze_transformer`: Freeze the corresponding components
Once the API server is running, you can send requests:
```python
import requests
import base64

# Load image
with open('image.jpg', 'rb') as f:
    image_data = base64.b64encode(f.read()).decode('utf-8')

# Send prediction request
response = requests.post('http://localhost:8080/predict', json={
    'image': image_data,
    'action': 'skewering',
    'food_type': 'banana'
})

result = response.json()
prediction = result['prediction']
# Contains: start_x, start_y, end_x, end_y, skewer_x, skewer_y, direction_degrees
execution_params = [prediction['start_x'], prediction['start_y'],
                    prediction['end_x'], prediction['end_y']]
```
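Since the execution parameters are normalized to [0, 1] (see the dataset description above), converting them to pixel coordinates is a matter of scaling by the image size, continuing from the snippet above:

```python
# Map normalized predictions back to pixel coordinates.
# 236x236 matches the dataset images; substitute your actual size.
w, h = 236, 236
start_px = (prediction['start_x'] * w, prediction['start_y'] * h)
end_px = (prediction['end_x'] * w, prediction['end_y'] * h)
print('start:', start_px, 'end:', end_px)
```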