MLP/LSTM-Based Autonomous Pong Agent

A supervised learning project, built with Python, PyTorch, and Pygame, that trains neural networks to play Pong by imitating human gameplay strategies.

🆕 Version 2.0 - LSTM Support

This version adds temporal learning via LSTM, enabling the model to:

  • Learn from sequences of past frames (not just current state)
  • Understand opponent movement patterns
  • Develop multi-step strategic planning
  • Discover deceptive play patterns

Project Overview

This project implements a complete pipeline for:

  1. Data Collection: Play Pong and record strategic gameplay data (frames + sequences)
  2. Model Training: Train MLP (reactive) or LSTM (strategic) neural networks
  3. Deployment: Deploy the trained model to play autonomously

Model Comparison

Feature         MLP (Reactive)              LSTM (Strategic)
Input           Single frame (6 features)   Sequence of 10 frames
Memory          None                        Temporal history
Strategy        Reacts to current state     Plans based on patterns
Training Data   ~5,000 frames               ~1,000 sequences
Parameters      ~12,000                     ~100,000
Use Case        Fast, simple gameplay       Complex, deceptive play
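
For orientation, here is a minimal sketch of what the reactive MLP could look like in PyTorch. The hidden-layer sizes are illustrative assumptions, not the exact architecture defined in pong_agent.py:

import torch.nn as nn

class PongMLP(nn.Module):
    """Reactive policy: maps one 6-feature frame to UP/DOWN/STAY logits."""
    def __init__(self, input_size=6, num_actions=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 64),    # hidden sizes are illustrative
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),  # logits for UP, DOWN, STAY
        )

    def forward(self, x):
        return self.net(x)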

Project Structure

pong_ai_project/
├── pong_game.py          # Game environment + data collection
├── pong_agent.py         # Neural network architectures (MLP + LSTM)
├── train_model.py        # Training pipeline
├── play_against_ai.py    # AI deployment script
├── requirements.txt      # Python dependencies
├── README.md             # This file
├── data/                 # Recorded gameplay data
│   ├── *.csv             # Frame data (MLP)
│   └── *_sequences.pkl   # Sequence data (LSTM)
└── models/               # Trained model checkpoints
    ├── best_model.pth        # Best MLP model
    └── best_model_lstm.pth   # Best LSTM model

Quick Start

1. Install Dependencies

pip install -r requirements.txt

2. Collect Training Data

Play the game and record your gameplay:

python pong_game.py

Controls:

  • W/S or ↑/↓ - Move paddle
  • SPACE - Pause/Resume
  • R - Toggle recording
  • Q - Save data and quit (saves both CSV and sequences)
  • ESC - Quit without saving

Tips for Strategic Data:

  • Play deceptively - fake one direction, hit another
  • Vary your timing and positioning
  • Try to predict and counter opponent patterns
  • Collect 15-30 minutes for good LSTM training

3. Train the Model

Train MLP (reactive, fast):

python train_model.py --model-type mlp

Train LSTM (strategic, recommended):

python train_model.py --model-type lstm --epochs 150

LSTM Options:

python train_model.py --model-type lstm \
    --hidden-size 128 \
    --num-layers 2 \
    --sequence-length 10 \
    --epochs 150

4. Test the AI

With MLP:

python play_against_ai.py --model models/best_model.pth

With LSTM:

python play_against_ai.py --model models/best_model_lstm.pth

Game Modes:

# AI vs Simple Opponent (default)
python play_against_ai.py --mode ai_vs_opponent

# Human vs Trained AI
python play_against_ai.py --mode human_vs_ai

# AI vs AI (self-play)
python play_against_ai.py --mode ai_vs_ai
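
Under the hood, deployment amounts to feeding the current normalized state (or the last 10 states, for the LSTM) through the trained network each frame and acting on the argmax. A minimal sketch, assuming an MLP-style model with a single output; the helper name is hypothetical and not the actual API of play_against_ai.py:

import torch

ACTIONS = ["UP", "DOWN", "STAY"]

def select_action(model, state):
    """Hypothetical helper: pick the paddle action for one normalized state tensor."""
    model.eval()
    with torch.no_grad():
        logits = model(state.unsqueeze(0))       # add a batch dimension
    return ACTIONS[logits.argmax(dim=-1).item()]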

Technical Details

Feature Space (6 normalized features)

Feature      Description                 Normalization
ball_x       Ball horizontal position    [0, 800] → [0, 1]
ball_y       Ball vertical position      [0, 600] → [0, 1]
ball_dx      Ball horizontal velocity    [-12, 12] → [0, 1]
ball_dy      Ball vertical velocity      [-12, 12] → [0, 1]
player_y     Player paddle position      [0, 600] → [0, 1]
opponent_y   Opponent paddle position    [0, 600] → [0, 1]
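
These ranges correspond to a simple shift-and-scale. A sketch of the mapping, with constants taken from the table above (the function name is illustrative, not necessarily what pong_game.py uses):

import numpy as np

def normalize_state(ball_x, ball_y, ball_dx, ball_dy, player_y, opponent_y):
    """Map raw game values to the 6 normalized features listed above."""
    return np.array([
        ball_x / 800.0,            # [0, 800]  -> [0, 1]
        ball_y / 600.0,            # [0, 600]  -> [0, 1]
        (ball_dx + 12.0) / 24.0,   # [-12, 12] -> [0, 1]
        (ball_dy + 12.0) / 24.0,   # [-12, 12] -> [0, 1]
        player_y / 600.0,          # [0, 600]  -> [0, 1]
        opponent_y / 600.0,        # [0, 600]  -> [0, 1]
    ], dtype=np.float32)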

LSTM Architecture

Input (10 timesteps × 6 features)
    ↓
LSTM Layer 1 (128 hidden, dropout=0.3)
    ↓
LSTM Layer 2 (128 hidden, dropout=0.3)
    ↓
Attention Mechanism (weighted timestep importance)
    ↓
FC Layer (128 → 64 → 32 → 3)
    ↓
Output (UP, DOWN, STAY)

Key Innovation - Attention Mechanism:

  • Learns which past frames are most relevant
  • Visualized in debug mode (green = recent, blue = older)
  • Enables focus on critical moments (ball approaching, opponent moving); a minimal sketch follows this list
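
A condensed sketch of how such an attention-weighted LSTM classifier can be wired up in PyTorch. Layer names and the exact attention formulation are assumptions; the real architecture lives in pong_agent.py:

import torch
import torch.nn as nn

class PongLSTM(nn.Module):
    """Strategic policy: 10 timesteps x 6 features -> UP/DOWN/STAY logits."""
    def __init__(self, input_size=6, hidden_size=128, num_layers=2, num_actions=3):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=0.3)
        self.attention = nn.Linear(hidden_size, 1)     # one score per timestep
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, num_actions),
        )

    def forward(self, x):                              # x: (batch, 10, 6)
        outputs, _ = self.lstm(x)                      # (batch, 10, hidden)
        weights = torch.softmax(self.attention(outputs), dim=1)  # timestep importance
        context = (weights * outputs).sum(dim=1)       # weighted sum over timesteps
        return self.head(context), weights.squeeze(-1)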

What LSTM Can Learn

Temporal Patterns:

  • "Opponent always moves UP after hitting high"
  • "Ball will reach me in ~10 frames at this speed"

Opponent Modeling:

  • Predict opponent's next move based on history
  • Exploit reaction delays and patterns

Strategic Sequences:

  • "If I hit low twice, opponent expects low → hit high"
  • Multi-step deceptive plays

Still Cannot Learn (needs RL):

  • Completely novel strategies not in training data
  • Optimal play against perfect opponents
  • Self-improvement through trial and error

Data Collection Strategy

Frame Data (MLP)

  • Records individual frames during rallies
  • Labels with success/failure based on rally outcome
  • Saved as CSV files (a labeling sketch follows this list)
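
As a rough illustration of outcome-based labeling only; the column layout and helper below are hypothetical, not the exact schema written by pong_game.py:

import csv

def save_rally_frames(frames, rally_won, path="data/session.csv"):
    """Append one rally's frames to CSV, tagging each with the rally outcome."""
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        for state, action in frames:           # state: 6 features, action: 0/1/2
            writer.writerow([*state, action, int(rally_won)])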

Sequence Data (LSTM)

  • Extracts sliding windows of 10 consecutive frames
  • Each sequence labeled with the action at the final frame
  • Saved as pickle files (efficient numpy arrays); a sketch of the window extraction follows this list
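
A sketch of the sliding-window extraction, assuming per-frame state arrays and integer action labels (the function name is illustrative):

import numpy as np

def extract_sequences(states, actions, seq_len=10):
    """Turn consecutive frames into (seq_len x 6) windows,
    each labeled with the action taken at its final frame."""
    sequences, labels = [], []
    for end in range(seq_len, len(states) + 1):
        sequences.append(states[end - seq_len:end])
        labels.append(actions[end - 1])
    return np.array(sequences, dtype=np.float32), np.array(labels, dtype=np.int64)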

Performance Tips

Data Collection

  • Play strategically, not randomly
  • Use varied tactics (aggressive, defensive, deceptive)
  • Collect at least 15 minutes for LSTM
  • More variety = better generalization

Training

  • LSTM needs more epochs than MLP (100-200)
  • Use gradient clipping (automatic; see the sketch after this list)
  • Monitor attention weights for learning progress
  • Lower learning rate if loss is unstable
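
For context, gradient clipping in a PyTorch training step typically looks like this. A generic sketch, not a copy of train_model.py; the max_norm value is an assumption:

import torch

def train_step(model, optimizer, loss_fn, batch_x, batch_y, max_norm=1.0):
    """One supervised update with gradient clipping (max_norm is illustrative)."""
    optimizer.zero_grad()
    loss = loss_fn(model(batch_x), batch_y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # keeps LSTM gradients stable
    optimizer.step()
    return loss.item()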

Expected Performance

Model   Data            Win Rate vs Imperfect AI
MLP     10k frames      60-70%
MLP     50k frames      70-80%
LSTM    5k sequences    70-80%
LSTM    20k sequences   80-90%

Troubleshooting

"No sequence files found"

  • Make sure you've saved data with Q key
  • Sequence files are created automatically from frame data
  • Check for *_sequences.pkl files in data/ directory

LSTM training is slow

  • Reduce --batch-size if GPU memory is limited
  • Reduce --hidden-size (64 instead of 128)
  • Use fewer --num-layers (1 instead of 2)

Low accuracy after training

  • Collect more diverse training data
  • Ensure you're using strategic play during collection
  • Try different sequence lengths (5, 10, 15)

Future Enhancements (Phase 3)

  • Reinforcement learning fine-tuning
  • Self-play improvement
  • Bidirectional LSTM
  • Transformer architecture
  • Opponent adaptation during gameplay

License

This project is for educational purposes.
