VLA Model Architectures¶
This page explores the architectural components and design patterns for Vision-Language-Action models.
Overview¶
A typical VLA architecture consists of four main components: a vision encoder, a language encoder, a multi-modal fusion module, and an action decoder (usually alongside a small robot-state encoder):
```mermaid
graph TD
    A[Visual Inputs] --> B[Vision Encoder]
    C[Language Instructions] --> D[Language Encoder]
    E[Robot State] --> F[State Encoder]
    B --> G[Multi-Modal Fusion]
    D --> G
    F --> G
    G --> H[Action Decoder]
    H --> I[Robot Actions]
```
Vision Encoder¶
The vision encoder processes visual observations from cameras and sensors.
Popular Architectures¶
A ViT-based encoder:

```python
import torch.nn as nn
from transformers import ViTModel

class ViTVisionEncoder(nn.Module):
    def __init__(self, hidden_dim, pretrained='google/vit-base-patch16-224'):
        super().__init__()
        self.encoder = ViTModel.from_pretrained(pretrained)
        self.projection = nn.Linear(768, hidden_dim)

    def forward(self, images):
        # images: (batch, channels, height, width)
        outputs = self.encoder(images)
        features = outputs.last_hidden_state  # (batch, num_patches, 768)
        return self.projection(features)
```
Advantages:

- Strong pre-trained representations
- Captures global context
- Scalable to large datasets
A ResNet-based encoder:

```python
import torch.nn as nn
import torchvision.models as models

class ResNetVisionEncoder(nn.Module):
    def __init__(self, hidden_dim, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Remove the final pooling and classification layers
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.projection = nn.Linear(2048, hidden_dim)

    def forward(self, images):
        features = self.encoder(images)  # (batch, 2048, H/32, W/32)
        pooled = self.pool(features).squeeze(-1).squeeze(-1)  # (batch, 2048)
        return self.projection(pooled)
```
Advantages:

- Proven architecture
- Efficient inference
- Strong spatial features
A CLIP-based encoder:

```python
import torch.nn as nn
import open_clip

class CLIPVisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            'ViT-B-32',
            pretrained='laion2b_s34b_b79k'
        )
        self.encoder = model.visual

    def forward(self, images):
        # Leverages vision-language pre-training;
        # images are assumed to be already resized and normalized for CLIP
        return self.encoder(images)  # (batch, 512) pooled embedding
```
Advantages:

- Vision-language alignment
- Strong zero-shot capabilities
- Internet-scale pre-training
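Whichever backbone you pick, usage is the same; a minimal sketch with the ViT encoder above (the 512-dim projection width is an arbitrary choice):

```python
import torch

encoder = ViTVisionEncoder(hidden_dim=512)
images = torch.randn(4, 3, 224, 224)  # batch of RGB frames
patch_features = encoder(images)
print(patch_features.shape)  # torch.Size([4, 197, 512]) for ViT-B/16 at 224px
```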
Multi-View Processing¶
For robots with multiple cameras:
```python
import torch
import torch.nn as nn

class MultiViewVisionEncoder(nn.Module):
    def __init__(self, single_view_encoder, num_views, hidden_dim):
        super().__init__()
        # One encoder per camera view (e.g. front, wrist, top)
        self.view_encoders = nn.ModuleList([
            single_view_encoder() for _ in range(num_views)
        ])
        self.fusion = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

    def forward(self, multi_view_images):
        # multi_view_images: dict with keys like 'front', 'wrist', 'top';
        # each encoder is assumed to return a pooled (batch, hidden_dim) feature
        view_features = []
        for view_name, encoder in zip(multi_view_images.keys(), self.view_encoders):
            view_features.append(encoder(multi_view_images[view_name]))
        # Stack views into a sequence and fuse with attention
        stacked_features = torch.stack(view_features, dim=1)  # (batch, num_views, hidden_dim)
        fused, _ = self.fusion(stacked_features, stacked_features, stacked_features)
        return fused
```
Language Encoder¶
Processes natural language instructions and task descriptions.
Architecture Options¶
A T5 encoder (per-token features):

```python
import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer

class T5LanguageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
        self.encoder = T5EncoderModel.from_pretrained('t5-base')

    def forward(self, instructions):
        # instructions: List[str]
        inputs = self.tokenizer(
            instructions,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        outputs = self.encoder(**inputs)
        return outputs.last_hidden_state  # (batch, seq_len, 768)
```
A BERT encoder (pooled sentence features):

```python
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class BERTLanguageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.encoder = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, instructions):
        inputs = self.tokenizer(
            instructions,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        outputs = self.encoder(**inputs)
        # Use the [CLS] token as a pooled sentence representation
        return outputs.last_hidden_state[:, 0, :]  # (batch, 768)
```
A CLIP text encoder (sentence embedding aligned with CLIP image features):

```python
import torch.nn as nn
import open_clip

class CLIPTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model, _, _ = open_clip.create_model_and_transforms(
            'ViT-B-32',
            pretrained='laion2b_s34b_b79k'
        )
        self.tokenizer = open_clip.get_tokenizer('ViT-B-32')

    def forward(self, instructions):
        tokens = self.tokenizer(instructions)
        return self.model.encode_text(tokens)  # (batch, 512)
```
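From the model's point of view the three options differ mainly in whether they return token-level features or a pooled sentence embedding; a quick usage sketch with the T5 variant:

```python
encoder = T5LanguageEncoder()
features = encoder(['pick up the red block', 'open the drawer'])
print(features.shape)  # (2, seq_len, 768) token-level features for t5-base
```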
Multi-Modal Fusion¶
Combining visual and language representations is critical for VLA performance.
Cross-Attention Fusion¶
```python
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads=8, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(num_layers)
        ])

    def forward(self, visual_features, language_features, state_features):
        # visual_features:   (batch, num_visual_tokens, hidden_dim)
        # language_features: (batch, seq_len, hidden_dim)
        # state_features:    (batch, hidden_dim), already projected to hidden_dim
        # Concatenate all modalities into one token sequence
        tokens = torch.cat([
            visual_features,
            language_features,
            state_features.unsqueeze(1)
        ], dim=1)
        # Each layer attends over the joint sequence (self- and cross-attention)
        for layer in self.layers:
            tokens = layer(tokens, tokens)
        return tokens
```
Gated Fusion¶
```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    def __init__(self, visual_dim, language_dim, output_dim):
        super().__init__()
        self.visual_proj = nn.Linear(visual_dim, output_dim)
        self.language_proj = nn.Linear(language_dim, output_dim)
        self.gate = nn.Sequential(
            nn.Linear(visual_dim + language_dim, output_dim),
            nn.Sigmoid()
        )

    def forward(self, visual_features, language_features):
        v = self.visual_proj(visual_features)
        l = self.language_proj(language_features)
        # The gate decides, per channel, how much to trust each modality
        gate = self.gate(torch.cat([visual_features, language_features], dim=-1))
        # Convex combination of the two projected modalities
        fused = gate * v + (1 - gate) * l
        return fused
```
FiLM (Feature-wise Linear Modulation)¶
```python
import torch.nn as nn

class FiLMFusion(nn.Module):
    def __init__(self, visual_dim, language_dim):
        super().__init__()
        # Language generates per-channel scaling and shifting parameters
        self.gamma_net = nn.Linear(language_dim, visual_dim)
        self.beta_net = nn.Linear(language_dim, visual_dim)

    def forward(self, visual_features, language_features):
        # language_features: pooled sentence embedding, (batch, language_dim)
        gamma = self.gamma_net(language_features)  # Scaling, (batch, visual_dim)
        beta = self.beta_net(language_features)    # Shifting, (batch, visual_dim)
        # Broadcast over the token dimension if visual_features is (batch, tokens, dim)
        if visual_features.dim() == 3:
            gamma, beta = gamma.unsqueeze(1), beta.unsqueeze(1)
        return gamma * visual_features + beta
```
Action Decoder¶
Generates robot actions from fused multi-modal representations.
Action Prediction Heads¶
A diffusion-style decoder (the `UNet1D` noise-prediction network and the `add_noise` / `denoise_step` scheduler helpers follow the standard DDPM recipe and are not shown):

```python
import torch
import torch.nn as nn

class DiffusionActionDecoder(nn.Module):
    def __init__(self, input_dim, action_dim, num_steps=100):
        super().__init__()
        # Conditional noise-prediction network (architecture not shown here)
        self.noise_predictor = UNet1D(
            input_dim=input_dim + action_dim,
            output_dim=action_dim
        )
        self.num_steps = num_steps
        self.action_dim = action_dim

    def forward(self, fused_features, actions=None):
        batch_size = fused_features.shape[0]
        if self.training:
            # Training: corrupt ground-truth actions and predict the added noise
            timesteps = torch.randint(0, self.num_steps, (batch_size,))
            noise = torch.randn_like(actions)
            noisy_actions = self.add_noise(actions, noise, timesteps)
            predicted_noise = self.noise_predictor(
                torch.cat([fused_features, noisy_actions], dim=-1),
                timesteps
            )
            return predicted_noise
        else:
            # Inference: iteratively denoise starting from pure noise
            actions = torch.randn(batch_size, self.action_dim, device=fused_features.device)
            for t in reversed(range(self.num_steps)):
                predicted_noise = self.noise_predictor(
                    torch.cat([fused_features, actions], dim=-1),
                    t
                )
                actions = self.denoise_step(actions, predicted_noise, t)
            return actions
```
An autoregressive decoder that predicts an action chunk one step at a time:

```python
import torch
import torch.nn as nn

class AutoregressiveActionDecoder(nn.Module):
    def __init__(self, input_dim, action_dim, chunk_size=10):
        super().__init__()
        self.chunk_size = chunk_size
        self.action_dim = action_dim
        # Project the fused features to the decoder width for use as memory
        self.memory_projection = nn.Linear(input_dim, 256)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=256, nhead=8, batch_first=True),
            num_layers=6
        )
        self.action_embedding = nn.Linear(action_dim, 256)
        self.output_projection = nn.Linear(256, action_dim)

    def forward(self, fused_features, past_actions=None):
        # fused_features: (batch, input_dim); past_actions: (batch, t, action_dim)
        batch_size = fused_features.shape[0]
        if past_actions is None:
            past_actions = torch.zeros(
                batch_size, 1, self.action_dim, device=fused_features.device
            )
        action_embeds = self.action_embedding(past_actions)
        memory = self.memory_projection(fused_features).unsqueeze(1)  # (batch, 1, 256)
        decoder_output = self.transformer(action_embeds, memory)
        predicted_actions = self.output_projection(decoder_output)
        return predicted_actions
```
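The complete model below also uses a plain MLP regression head. A minimal sketch (the name `MLPActionDecoder` and its layer sizes are illustrative assumptions, not from a specific library):

```python
import torch.nn as nn

class MLPActionDecoder(nn.Module):
    """Minimal MLP head that regresses a single continuous action vector."""
    def __init__(self, input_dim, action_dim, hidden_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, fused_features):
        # fused_features: (batch, input_dim) -> actions: (batch, action_dim)
        return self.net(fused_features)
```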
Complete VLA Architecture¶
Putting it all together:
```python
import torch.nn as nn

class VLAModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Encoders
        self.vision_encoder = CLIPVisionEncoder()
        self.language_encoder = T5LanguageEncoder()
        self.state_encoder = nn.Linear(config.state_dim, config.hidden_dim)
        # Project encoder outputs to a shared width
        # (CLIP ViT-B/32 image embeddings are 512-d, t5-base token features are 768-d)
        self.visual_proj = nn.Linear(512, config.hidden_dim)
        self.language_proj = nn.Linear(768, config.hidden_dim)
        # Fusion
        self.fusion = CrossAttentionFusion(
            hidden_dim=config.hidden_dim,
            num_heads=8,
            num_layers=4
        )
        # Action decoder (simple MLP head defined above)
        self.action_decoder = MLPActionDecoder(
            input_dim=config.hidden_dim,
            action_dim=config.action_dim
        )

    def forward(self, observations):
        # Encode each modality and project to the shared hidden width
        visual_features = self.visual_proj(
            self.vision_encoder(observations['image'])
        ).unsqueeze(1)                                      # (batch, 1, hidden_dim)
        language_features = self.language_proj(
            self.language_encoder(observations['instruction'])
        )                                                   # (batch, seq_len, hidden_dim)
        state_features = self.state_encoder(observations['robot_state'])  # (batch, hidden_dim)
        # Fuse modalities into one token sequence
        fused = self.fusion(visual_features, language_features, state_features)
        # Decode actions from the first fused token
        actions = self.action_decoder(fused[:, 0, :])
        return actions
```
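A hypothetical end-to-end usage sketch (the config fields and tensor shapes are assumptions that match the code above):

```python
from types import SimpleNamespace
import torch

config = SimpleNamespace(hidden_dim=512, state_dim=7, action_dim=7)
model = VLAModel(config)

observations = {
    'image': torch.randn(2, 3, 224, 224),            # batch of RGB frames
    'instruction': ['pick up the red block'] * 2,     # natural-language commands
    'robot_state': torch.randn(2, config.state_dim),  # e.g. joint positions
}
actions = model(observations)
print(actions.shape)  # torch.Size([2, 7])
```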
Architecture Variants¶
RT-1 Style¶
- Vision: EfficientNet
- Language: Universal Sentence Encoder
- Fusion: Token-based Transformer
- Action: Discretized action space
RT-2 Style¶
- Vision: ViT (from pre-trained VLM)
- Language: T5
- Fusion: Vision-Language Model backbone
- Action: Co-fine-tuned with VLM
OpenVLA Style¶
- Vision: SigLIP
- Language: Llama-based
- Fusion: Integrated multi-modal transformer
- Action: Continuous action prediction
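These variants can be read as different choices for the building blocks above. A hypothetical configuration sketch showing how such choices might be expressed (the field names and values are illustrative; this is not the actual RT-1, RT-2, or OpenVLA code):

```python
from dataclasses import dataclass

@dataclass
class VLAConfig:
    # Illustrative knobs only, mirroring the components in this page
    vision_encoder: str = 'clip_vit_b32'   # e.g. 'efficientnet', 'vit', 'siglip'
    language_encoder: str = 't5-base'      # e.g. 'use', 't5-base', 'llama'
    fusion: str = 'cross_attention'        # e.g. 'token_transformer', 'vlm_backbone'
    action_head: str = 'mlp'               # e.g. 'discrete_bins', 'diffusion', 'autoregressive'
    hidden_dim: int = 512
    state_dim: int = 7
    action_dim: int = 7
```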
Design Considerations¶
Model Size vs. Performance¶
| Size | Parameters | Inference Speed | Performance |
|---|---|---|---|
| Small | < 100M | Fast (>30 Hz) | Good for simple tasks |
| Medium | 100M-1B | Medium (10-30 Hz) | General purpose |
| Large | > 1B | Slow (<10 Hz) | Best generalization |
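To see roughly where a model falls in this table, a quick profiling sketch (reusing `model` and `observations` from the usage example above; real control rates also depend on hardware and on any tokenization or denoising steps):

```python
import time
import torch

def profile_model(model, observations, warmup=3, iters=20):
    """Rough parameter count and single-batch inference rate."""
    num_params = sum(p.numel() for p in model.parameters())
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(observations)
        start = time.perf_counter()
        for _ in range(iters):
            model(observations)
        hz = iters / (time.perf_counter() - start)
    return num_params, hz

params, hz = profile_model(model, observations)
print(f"{params / 1e6:.0f}M parameters, ~{hz:.1f} Hz")
```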
Action Representation¶
Choose based on your robot and task:
- End-effector: More intuitive, easier sim-to-real
- Joint space: More precise control
- Delta actions: More stable, less prone to drift (a small conversion sketch follows this list)
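A minimal sketch of converting absolute end-effector targets into delta actions (conventions for rotation and gripper handling vary by robot; this assumes position-only targets):

```python
import torch

def to_delta_actions(ee_positions):
    # ee_positions: (T, 3) absolute end-effector positions along a trajectory
    # returns (T-1, 3) per-step deltas relative to the previous position
    return ee_positions[1:] - ee_positions[:-1]

def apply_delta(current_position, delta, max_step=0.05):
    # Clip the predicted delta to a safe per-step magnitude before applying it
    return current_position + torch.clamp(delta, -max_step, max_step)
```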
Next Steps¶
- Training VLA Models - Learn how to train these architectures
- Inference Guide - Deploy VLA models on robots
- LeRobot Dataset - Prepare data for training