Hyperparameter tuning using Ray Tune#

Created On: Aug 31, 2020 | Last Updated: Jan 08, 2026 | Last Verified: Nov 05, 2024

Author: Ricardo Decal

This tutorial shows how to integrate Ray Tune into your PyTorch training workflow to perform scalable and efficient hyperparameter tuning.

What you will learn
  • How to modify a PyTorch training loop for Ray Tune

  • How to scale a hyperparameter sweep to multiple nodes and GPUs without code changes

  • How to define a hyperparameter search space and run a sweep with tune.Tuner

  • How to use an early-stopping scheduler (ASHA) and report metrics/checkpoints

  • How to use checkpointing to resume training and load the best model

Prerequisites
  • PyTorch v2.9+ and torchvision

  • Ray Tune (ray[tune]) v2.52.1+

  • GPU(s) are optional, but recommended for faster training

Ray, a project of the PyTorch Foundation, is an open source unified framework for scaling AI and Python applications. It helps run distributed jobs by handling the complexity of distributed computing. Ray Tune is a library built on Ray for hyperparameter tuning that enables you to scale a hyperparameter sweep from your machine to a large cluster with no code changes.

This tutorial adapts the PyTorch tutorial for training a CIFAR10 classifier to run multi-GPU hyperparameter sweeps with Ray Tune.

Setup#

To run this tutorial, install the following dependencies:

pip install "ray[tune]" torchvision

Then start with the imports:

from functools import partial
import os
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
# New: imports for Ray Tune
import ray
from ray import tune
from ray.tune import Checkpoint
from ray.tune.schedulers import ASHAScheduler

Data loading#

Wrap dataset creation in a function. In this tutorial, a shared data directory is passed to the function so that every trial reuses the same downloaded dataset. In a cluster environment, you can use shared storage, such as a network file system, to prevent each node from downloading the data separately.

def load_data(data_dir="./data"):
    # Mean and standard deviation of the CIFAR10 training subset.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.4914, 0.48216, 0.44653), (0.2022, 0.19932, 0.20086))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset

Model architecture#

This tutorial searches for the best sizes for the fully connected layers and the learning rate. To enable this, the Net class exposes the layer sizes l1 and l2 as configurable parameters that Ray Tune can search over:

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Define the search space#

Next, define the hyperparameters to tune and how Ray Tune samples them. Ray Tune offers a variety of search space primitives to suit different parameter types: loguniform, uniform, choice, randint, grid_search, and more. You can also express complex dependencies between parameters with conditional search spaces or sample from arbitrary functions.

Here is the search space for this tutorial:

config = {
    "l1": tune.choice([2**i for i in range(9)]),
    "l2": tune.choice([2**i for i in range(9)]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}

tune.choice() accepts a list of values and samples from them uniformly. In this example, l1 and l2 are powers of 2 between 1 and 256, and the learning rate is sampled on a log scale between 0.0001 and 0.1. Sampling on a log scale covers each order of magnitude equally, rather than concentrating samples near the upper end of the range as uniform sampling would.
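To make the effect of log-scale sampling concrete, the following standard-library sketch compares uniform and log-uniform draws over the same range. It only illustrates the idea and is not Ray Tune's implementation; the helper names `sample_loguniform` and `fraction_below` are hypothetical.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def fraction_below(samples, threshold):
    """Fraction of samples strictly below the threshold."""
    return sum(s < threshold for s in samples) / len(samples)


rng = random.Random(0)  # fixed seed so the comparison is repeatable
n = 10_000
uniform_samples = [rng.uniform(1e-4, 1e-1) for _ in range(n)]
loguniform_samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(n)]

# The range spans three decades, so roughly one third of the log-uniform
# samples fall below 1e-3, while only about 1% of the uniform samples do.
print(f"uniform    < 1e-3: {fraction_below(uniform_samples, 1e-3):.3f}")
print(f"loguniform < 1e-3: {fraction_below(loguniform_samples, 1e-3):.3f}")
```

With uniform sampling, almost no trials would try learning rates in the 0.0001 to 0.001 decade, which is why a log scale is the usual choice for this hyperparameter.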

Training function#

Ray Tune requires a training function that accepts a configuration dictionary and runs the main training loop. As Ray Tune runs different trials, it updates the configuration dictionary for each trial.

Here is the full training function, followed by explanations of the key Ray Tune integration points:

def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])
    device = config["device"]

    net = net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Load checkpoint if resuming training
    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
            checkpoint_state = torch.load(checkpoint_path)
            # The checkpoint stores the last completed epoch, so resume at the next one.
            start_epoch = checkpoint_state["epoch"] + 1
            net.load_state_dict(checkpoint_state["net_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0

    trainset, _testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )

    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0
                epoch_steps = 0  # reset so the printed value averages the last 2000 batches

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_steps += 1

        # Save checkpoint and report metrics
        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
            torch.save(checkpoint_data, checkpoint_path)

            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            tune.report(
                {"loss": val_loss / val_steps, "accuracy": correct / total},
                checkpoint=checkpoint,
            )

    print("Finished Training")

Key integration points#

Using hyperparameters from the configuration dictionary#

Ray Tune updates the config dictionary with the hyperparameters for each trial. In this example, the model architecture and optimizer receive the hyperparameters from the config dictionary:

net = Net(config["l1"], config["l2"])
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

Reporting metrics and saving checkpoints#

The most important integration is communicating with Ray Tune. Ray Tune uses the validation metrics to determine the best hyperparameter configuration and to stop underperforming trials early, saving resources.

Checkpointing lets you load trained models later, resume hyperparameter searches, and recover from failures. It’s also required for some Ray Tune schedulers, such as Population Based Training, that pause and resume trials during the search.

This code from the training function loads model and optimizer state at the start if a checkpoint exists:

checkpoint = tune.get_checkpoint()
if checkpoint:
    with checkpoint.as_directory() as checkpoint_dir:
        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
        checkpoint_state = torch.load(checkpoint_path)
        # The checkpoint stores the last completed epoch, so resume at the next one.
        start_epoch = checkpoint_state["epoch"] + 1
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])

At the end of each epoch, save a checkpoint and report the validation metrics:

checkpoint_data = {
    "epoch": epoch,
    "net_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
with tempfile.TemporaryDirectory() as checkpoint_dir:
    checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
    torch.save(checkpoint_data, checkpoint_path)

    checkpoint = Checkpoint.from_directory(checkpoint_dir)
    tune.report(
        {"loss": val_loss / val_steps, "accuracy": correct / total},
        checkpoint=checkpoint,
    )

Ray Tune checkpointing supports local file systems, cloud storage, and distributed file systems. For more information, see the Ray Tune storage documentation.
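The save-then-resume pattern can be tried in isolation with a small standard-library sketch. Here `json` stands in for `torch.save` and `torch.load`, a counter stands in for the model and optimizer state, and the function names are hypothetical. The sketch assumes the checkpoint records the last completed epoch, so a resumed run starts at the following one.

```python
import json
import tempfile
from pathlib import Path


def save_checkpoint(checkpoint_dir, epoch, state):
    """Write the last completed epoch and a toy training state to disk."""
    path = Path(checkpoint_dir) / "checkpoint.json"
    path.write_text(json.dumps({"epoch": epoch, "state": state}))


def load_checkpoint(checkpoint_dir):
    """Return the saved dictionary, or None when no checkpoint exists."""
    path = Path(checkpoint_dir) / "checkpoint.json"
    if not path.exists():
        return None
    return json.loads(path.read_text())


def train(checkpoint_dir, num_epochs, stop_after=None):
    """Run a toy loop that resumes after the last completed epoch."""
    saved = load_checkpoint(checkpoint_dir)
    start_epoch = saved["epoch"] + 1 if saved else 0
    state = saved["state"] if saved else {"steps": 0}
    epochs_run = []
    for epoch in range(start_epoch, num_epochs):
        state["steps"] += 100  # placeholder for one epoch of real training
        save_checkpoint(checkpoint_dir, epoch, state)
        epochs_run.append(epoch)
        if stop_after is not None and len(epochs_run) == stop_after:
            break  # simulate an interruption
    return epochs_run, state


with tempfile.TemporaryDirectory() as checkpoint_dir:
    first_run, _ = train(checkpoint_dir, num_epochs=5, stop_after=2)
    second_run, final_state = train(checkpoint_dir, num_epochs=5)

print(first_run)    # [0, 1]
print(second_run)   # [2, 3, 4]
print(final_state)  # {'steps': 500}
```

The second call picks up where the first one stopped, so no epoch runs twice and the accumulated state carries over.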

Multi-GPU support#

GPUs can greatly accelerate image classification models. When more than one GPU is visible to a trial, the training function wraps the model in nn.DataParallel, which splits each batch across the available GPUs.

This training function supports training on CPUs, a single GPU, multiple GPUs, or multiple nodes without code changes. Ray Tune automatically distributes the trials across the nodes according to the available resources. Ray Tune also supports fractional GPUs so that one GPU can be shared among multiple trials, provided that the models, optimizers, and data batches fit into the GPU memory.
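One practical consequence of nn.DataParallel is that the wrapped model's state_dict keys gain a `module.` prefix, so a checkpoint saved from a wrapped model does not load directly into an unwrapped Net. The sketch below shows one way to normalize the keys. It uses a plain dictionary with placeholder values as a stand-in for a real state dict, and `strip_module_prefix` is a hypothetical helper, not a PyTorch or Ray API.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for the tensors of a real state dict.
wrapped_state = {"module.fc1.weight": "W1", "module.fc1.bias": "b1"}
plain_state = strip_module_prefix(wrapped_state)
print(sorted(plain_state))  # ['fc1.bias', 'fc1.weight']
```

Keys without the prefix pass through unchanged, so the helper is safe to apply whether or not the checkpoint came from a wrapped model.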

Validation split#

The original CIFAR10 dataset only has train and test subsets. This is sufficient for training a single model, but hyperparameter tuning also requires a validation subset so that the test subset is not used to select hyperparameters. The training function creates a validation subset by reserving 20% of the training subset. The test subset is used to evaluate the best model’s generalization error after the search completes.
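The arithmetic of the split can be checked with a standard-library sketch that shuffles indices instead of images. The helper `split_indices` and the fixed seed are assumptions made for illustration; the training function itself uses `random_split`.

```python
import random


def split_indices(num_samples, train_fraction=0.8, seed=0):
    """Shuffle sample indices and split them into train and validation parts."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    num_train = int(num_samples * train_fraction)
    return indices[:num_train], indices[num_train:]


# The CIFAR10 training subset contains 50,000 images.
train_idx, val_idx = split_indices(50_000)
print(len(train_idx), len(val_idx))  # 40000 10000
```

Seeding the shuffle gives every trial the same validation subset, which keeps their losses directly comparable. With `random_split`, passing a seeded `torch.Generator` through its `generator` argument has the same effect.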

Evaluation function#

After the search identifies the best hyperparameters, evaluate the resulting model on the held-out test set to estimate its generalization error:

def test_accuracy(net, device="cpu", data_dir=None):
    _trainset, testset = load_data(data_dir)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            image_batch, labels = data
            image_batch, labels = image_batch.to(device), labels.to(device)
            outputs = net(image_batch)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

Configure and run Ray Tune#

With the training and evaluation functions defined, configure Ray Tune to run the hyperparameter search.

Scheduler for early stopping#

Ray Tune provides schedulers to improve the efficiency of the hyperparameter search by detecting underperforming trials and stopping them early. The ASHAScheduler uses the Asynchronous Successive Halving Algorithm (ASHA) to aggressively terminate low-performing trials:

scheduler = ASHAScheduler(
    max_t=max_num_epochs,
    grace_period=1,
    reduction_factor=2,
)

Ray Tune also provides search algorithms that choose the next set of hyperparameters based on previous results, instead of relying only on random or grid search. Examples include Optuna and BayesOpt.
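The halving idea behind the scheduler can be sketched in a few lines. This is a simplified, synchronous version written for illustration: ASHA makes the same kind of decision asynchronously as results arrive, so the code below is not Ray Tune's scheduler. The function names and the synthetic loss curves are assumptions.

```python
def rung_milestones(grace_period, reduction_factor, max_t):
    """Iterations at which trials are compared: grace_period * factor**k."""
    milestones = []
    t = grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def successive_halving(trial_losses, milestones, reduction_factor):
    """Keep the best 1/reduction_factor of the trials at every milestone.

    trial_losses maps a trial name to its per-iteration validation losses.
    Returns the surviving trial names after each milestone.
    """
    alive = list(trial_losses)
    history = {}
    for t in milestones:
        alive.sort(key=lambda name: trial_losses[name][t - 1])
        keep = max(1, len(alive) // reduction_factor)
        alive = alive[:keep]
        history[t] = list(alive)
    return history


# Synthetic loss curves: trial_0 is always best, trial_7 always worst.
curves = {f"trial_{i}": [(i + 1) / t for t in range(1, 11)] for i in range(8)}

milestones = rung_milestones(grace_period=1, reduction_factor=2, max_t=10)
print(milestones)  # [1, 2, 4, 8]

history = successive_halving(curves, milestones, reduction_factor=2)
for t, survivors in history.items():
    print(t, survivors)
```

With eight trials and a reduction factor of 2, half of the trials are stopped after the first epoch, so most of the compute goes to the configurations that look promising early.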

Resource allocation#

Tell Ray Tune what resources to allocate for each trial by passing a resources dictionary to tune.with_resources:

tune.with_resources(
    partial(train_cifar, data_dir=data_dir),
    resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
)

Ray Tune automatically manages the placement of these trials and ensures that the trials run in isolation, so you don’t need to manually assign GPUs to processes.

For example, if you are running this experiment on a cluster of 20 machines, each with 8 GPUs, you can set gpus_per_trial = 0.5 to schedule two concurrent trials per GPU. This configuration runs 320 trials in parallel across the cluster.
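The trial count in this example follows from simple arithmetic, sketched below as a rough estimate. The function is hypothetical and assumes that each trial runs on a single node and that a fractional GPU request must fit on one GPU. The number of CPUs per machine is also an assumption, because the example does not specify it.

```python
def max_concurrent_trials(num_nodes, gpus_per_node, cpus_per_node,
                          gpus_per_trial, cpus_per_trial):
    """Estimate how many trials fit at once when each trial stays on one node."""
    by_cpu = cpus_per_node // cpus_per_trial
    if gpus_per_trial <= 0:
        return num_nodes * by_cpu
    if gpus_per_trial < 1:
        # Fractional requests share a GPU; epsilon guards against float error.
        by_gpu = gpus_per_node * int(1 / gpus_per_trial + 1e-9)
    else:
        by_gpu = int(gpus_per_node / gpus_per_trial + 1e-9)
    return num_nodes * min(by_cpu, by_gpu)


# 20 machines with 8 GPUs each and an assumed 64 CPUs per machine.
print(max_concurrent_trials(20, 8, 64, gpus_per_trial=0.5, cpus_per_trial=2))  # 320
# With only 16 CPUs per machine, CPUs become the bottleneck.
print(max_concurrent_trials(20, 8, 16, gpus_per_trial=0.5, cpus_per_trial=2))  # 160
```

The second call shows that the CPU request matters as much as the GPU request: a trial starts only when both are available on the same node.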

Note

To run this tutorial without GPUs, set gpus_per_trial=0 and expect significantly longer runtimes.

To avoid long runtimes during development, start with a small number of trials and epochs.

Creating the Tuner#

The Ray Tune API is modular and composable. Pass your configuration to the tune.Tuner class to create a tuner object, then run tuner.fit() to start training:

tuner = tune.Tuner(
    tune.with_resources(
        partial(train_cifar, data_dir=data_dir),
        resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
    ),
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        scheduler=scheduler,
        num_samples=num_trials,
    ),
    param_space=config,
)
results = tuner.fit()

After training completes, retrieve the best performing trial, load its checkpoint, and evaluate on the test set.
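Conceptually, selecting the best trial means comparing one metrics dictionary per trial, assumed here to be the last one each trial reported. The sketch below is a plain-Python stand-in for that selection and is not the Ray Tune API. The `best_result` helper is hypothetical and the metric values are made up for illustration.

```python
import math


def best_result(results, metric="loss", mode="min"):
    """Pick the entry with the best finite metric value."""
    valid = [r for r in results if math.isfinite(r["metrics"][metric])]
    if not valid:
        raise ValueError(f"no trial reported a finite {metric!r}")
    pick = min if mode == "min" else max
    return pick(valid, key=lambda r: r["metrics"][metric])


# Made-up results, one entry per trial.
results = [
    {"config": {"l1": 16, "lr": 1e-3}, "metrics": {"loss": 1.25, "accuracy": 0.55}},
    {"config": {"l1": 64, "lr": 1e-2}, "metrics": {"loss": float("nan"), "accuracy": 0.10}},
    {"config": {"l1": 128, "lr": 3e-4}, "metrics": {"loss": 1.10, "accuracy": 0.61}},
]

best = best_result(results)
print(best["config"])  # {'l1': 128, 'lr': 0.0003}
```

Filtering out non-finite values matters in practice because a diverged trial can report a NaN loss, which would otherwise make the comparison unreliable.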

Putting it all together#

def main(num_trials=10, max_num_epochs=10, gpus_per_trial=0, cpus_per_trial=2):
    print("Starting hyperparameter tuning.")
    ray.init(include_dashboard=False)

    data_dir = os.path.abspath("./data")
    load_data(data_dir)  # Pre-download the dataset
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = {
        "l1": tune.choice([2**i for i in range(9)]),
        "l2": tune.choice([2**i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
        "device": device,
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    tuner = tune.Tuner(
        tune.with_resources(
            partial(train_cifar, data_dir=data_dir),
            resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_trials,
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")
    print(f"Best trial config: {best_result.config}")
    print(f"Best trial final validation loss: {best_result.metrics['loss']}")
    print(f"Best trial final validation accuracy: {best_result.metrics['accuracy']}")

    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    best_trained_model = best_trained_model.to(device)
    if gpus_per_trial > 1:
        best_trained_model = nn.DataParallel(best_trained_model)

    best_checkpoint = best_result.checkpoint
    with best_checkpoint.as_directory() as checkpoint_dir:
        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
        best_checkpoint_data = torch.load(checkpoint_path)

        best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
        test_acc = test_accuracy(best_trained_model, device, data_dir)
        print(f"Best trial test set accuracy: {test_acc}")


if __name__ == "__main__":
    # Set the number of trials, epochs, and GPUs per trial here:
    main(num_trials=10, max_num_epochs=10, gpus_per_trial=1)

Running the script produces output similar to the following:

Starting hyperparameter tuning.
2026-01-16 22:30:42,326 WARNING services.py:2137 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 2147471360 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2026-01-16 22:30:42,495 INFO worker.py:2023 -- Started a local Ray instance.
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2062: FutureWarning:

Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0


100%|██████████| 170M/170M [00:02<00:00, 67.6MB/s]
╭────────────────────────────────────────────────────────────────────╮
│ Configuration for experiment     train_cifar_2026-01-16_22-30-48   │
├────────────────────────────────────────────────────────────────────┤
│ Search algorithm                 BasicVariantGenerator             │
│ Scheduler                        AsyncHyperBandScheduler           │
│ Number of trials                 10                                │
╰────────────────────────────────────────────────────────────────────╯

View detailed results here: /var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2026-01-16_22-30-41_075213_3477/artifacts/2026-01-16_22-30-48/train_cifar_2026-01-16_22-30-48/driver_artifacts`

Trial status: 10 PENDING
Current time: 2026-01-16 22:30:49. Total running time: 0s
Logical resource usage: 0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
╭───────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size │
├───────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   PENDING      16     16   0.000333728              8 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2 │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2 │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16 │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4 │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16 │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4 │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16 │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8 │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8 │
╰───────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00000 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00000 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     8 │
│ device                                      cuda │
│ l1                                            16 │
│ l2                                            16 │
│ lr                                       0.00033 │
╰──────────────────────────────────────────────────╯
(func pid=4591) [1,  2000] loss: 2.219
(func pid=4591) [1,  4000] loss: 0.961
(pid=gcs_server) [2026-01-16 22:31:11,404 E 3482 3482] (gcs_server) gcs_server.cc:303: Failed to establish connection to the event+metrics exporter agent. Events and metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000000)
(raylet) [2026-01-16 22:31:12,429 E 3622 3622] (raylet) main.cc:979: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(bundle_reservation_check_func pid=3697) [2026-01-16 22:31:13,106 E 3697 4022] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=4591) [2,  2000] loss: 1.705

Trial status: 1 RUNNING | 9 PENDING
Current time: 2026-01-16 22:31:19. Total running time: 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.7284736596107484 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   RUNNING      16     16   0.000333728              8        1            18.4922   1.72847       0.3472 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=4591) [2026-01-16 22:31:19,908 E 4591 4626] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14 [repeated 14x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(func pid=4591) [2,  4000] loss: 0.816
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000001)
(func pid=4591) [3,  2000] loss: 1.521
(func pid=4591) [3,  4000] loss: 0.729
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000002)
Trial status: 1 RUNNING | 9 PENDING
Current time: 2026-01-16 22:31:49. Total running time: 1min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.4108793731212617 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   RUNNING      16     16   0.000333728              8        3            51.6524   1.41088       0.4779 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=4591) [4,  2000] loss: 1.385
(func pid=4591) [4,  4000] loss: 0.676
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000003)
(func pid=4591) [5,  2000] loss: 1.304
(func pid=4591) [5,  4000] loss: 0.650
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000004)
Trial status: 1 RUNNING | 9 PENDING
Current time: 2026-01-16 22:32:19. Total running time: 1min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.2982678125858307 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   RUNNING      16     16   0.000333728              8        5            84.5807   1.29827       0.5252 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=4591) [6,  2000] loss: 1.245
(func pid=4591) [6,  4000] loss: 0.621
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000005)
(func pid=4591) [7,  2000] loss: 1.186
(func pid=4591) [7,  4000] loss: 0.601
Trial status: 1 RUNNING | 9 PENDING
Current time: 2026-01-16 22:32:49. Total running time: 2min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.2507929501771926 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   RUNNING      16     16   0.000333728              8        6            100.898   1.25079       0.5501 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000006)
(func pid=4591) [8,  2000] loss: 1.169
(func pid=4591) [8,  4000] loss: 0.582
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000007)
(func pid=4591) [9,  2000] loss: 1.126
(func pid=4591) [9,  4000] loss: 0.569
Trial status: 1 RUNNING | 9 PENDING
Current time: 2026-01-16 22:33:19. Total running time: 2min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1945622762084007 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   RUNNING      16     16   0.000333728              8        8            133.585   1.19456       0.5719 │
│ train_cifar_01e85_00001   PENDING       1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00002   PENDING     256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING      32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING     256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING      32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING     256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING       4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING     128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING       4      1   0.092961                 8                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000008)
(func pid=4591) [10,  2000] loss: 1.104
(func pid=4591) [10,  4000] loss: 0.558

Trial train_cifar_01e85_00000 completed after 10 iterations at 2026-01-16 22:33:39. Total running time: 2min 50s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00000 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000009 │
│ time_this_iter_s                                  16.45474 │
│ time_total_s                                     166.71371 │
│ training_iteration                                      10 │
│ accuracy                                            0.5875 │
│ loss                                               1.15531 │
╰────────────────────────────────────────────────────────────╯
(func pid=4591) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00000_0_batch_size=8,l1=16,l2=16,lr=0.0003_2026-01-16_22-30-49/checkpoint_000009)

Trial train_cifar_01e85_00001 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00001 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     2 │
│ device                                      cuda │
│ l1                                             1 │
│ l2                                           256 │
│ lr                                       0.00338 │
╰──────────────────────────────────────────────────╯

Trial status: 1 TERMINATED | 1 RUNNING | 8 PENDING
Current time: 2026-01-16 22:33:49. Total running time: 3min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00001   RUNNING         1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10            166.714   1.15531       0.5875 │
│ train_cifar_01e85_00002   PENDING       256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING        32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=5321) [1,  2000] loss: 2.411
(func pid=5321) [1,  4000] loss: 1.170
(func pid=5321) [1,  6000] loss: 0.773
(func pid=5321) [1,  8000] loss: 0.578
(func pid=5321) [2026-01-16 22:34:10,798 E 5321 5356] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=5321) [1, 10000] loss: 0.461
(func pid=5321) [1, 12000] loss: 0.385
Trial status: 1 TERMINATED | 1 RUNNING | 8 PENDING
Current time: 2026-01-16 22:34:19. Total running time: 3min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00001   RUNNING         1    256   0.00338356               2                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10            166.714   1.15531       0.5875 │
│ train_cifar_01e85_00002   PENDING       256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING        32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=5321) [1, 14000] loss: 0.330
(func pid=5321) [1, 16000] loss: 0.288
(func pid=5321) [1, 18000] loss: 0.256
(func pid=5321) [1, 20000] loss: 0.231

Trial train_cifar_01e85_00001 completed after 1 iterations at 2026-01-16 22:34:48. Total running time: 3min 59s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00001 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000000 │
│ time_this_iter_s                                  64.51299 │
│ time_total_s                                      64.51299 │
│ training_iteration                                       1 │
│ accuracy                                            0.0988 │
│ loss                                               2.30847 │
╰────────────────────────────────────────────────────────────╯
(func pid=5321) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00001_1_batch_size=2,l1=1,l2=256,lr=0.0034_2026-01-16_22-30-49/checkpoint_000000)

Trial status: 2 TERMINATED | 8 PENDING
Current time: 2026-01-16 22:34:49. Total running time: 4min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10            166.714   1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1             64.513   2.30847       0.0988 │
│ train_cifar_01e85_00002   PENDING       256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00003   PENDING        32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00002 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00002 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     2 │
│ device                                      cuda │
│ l1                                           256 │
│ l2                                            32 │
│ lr                                       0.03111 │
╰──────────────────────────────────────────────────╯
(func pid=5456) [1,  2000] loss: 2.337
(func pid=5456) [1,  4000] loss: 1.170
(func pid=5456) [1,  6000] loss: 0.780
(func pid=5456) [1,  8000] loss: 0.585

Trial status: 2 TERMINATED | 1 RUNNING | 7 PENDING
Current time: 2026-01-16 22:35:19. Total running time: 4min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00002   RUNNING       256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10            166.714   1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1             64.513   2.30847       0.0988 │
│ train_cifar_01e85_00003   PENDING        32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=5456) [2026-01-16 22:35:19,807 E 5456 5491] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=5456) [1, 10000] loss: 0.468
(func pid=5456) [1, 12000] loss: 0.390
(func pid=5456) [1, 14000] loss: 0.334
(func pid=5456) [1, 16000] loss: 0.292
(func pid=5456) [1, 18000] loss: 0.260
(func pid=5456) [1, 20000] loss: 0.234
Trial status: 2 TERMINATED | 1 RUNNING | 7 PENDING
Current time: 2026-01-16 22:35:49. Total running time: 5min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00002   RUNNING       256     32   0.0311106                2                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10            166.714   1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1             64.513   2.30847       0.0988 │
│ train_cifar_01e85_00003   PENDING        32      8   0.000513478             16                                                    │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00002 completed after 1 iterations at 2026-01-16 22:35:56. Total running time: 5min 7s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00002 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000000 │
│ time_this_iter_s                                  63.67258 │
│ time_total_s                                      63.67258 │
│ training_iteration                                       1 │
│ accuracy                                            0.1026 │
│ loss                                               2.31367 │
╰────────────────────────────────────────────────────────────╯
(func pid=5456) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00002_2_batch_size=2,l1=256,l2=32,lr=0.0311_2026-01-16_22-30-49/checkpoint_000000)

Trial train_cifar_01e85_00003 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00003 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                    16 │
│ device                                      cuda │
│ l1                                            32 │
│ l2                                             8 │
│ lr                                       0.00051 │
╰──────────────────────────────────────────────────╯
(func pid=5592) [1,  2000] loss: 2.259
(func pid=5592) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00003_3_batch_size=16,l1=32,l2=8,lr=0.0005_2026-01-16_22-30-49/checkpoint_000000)
(func pid=5592) [2,  2000] loss: 1.903

Trial status: 3 TERMINATED | 1 RUNNING | 6 PENDING
Current time: 2026-01-16 22:36:19. Total running time: 5min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00003   RUNNING        32      8   0.000513478             16        1            11.0541   2.04646       0.2233 │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00004   PENDING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00003 completed after 2 iterations at 2026-01-16 22:36:20. Total running time: 5min 31s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00003 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000001 │
│ time_this_iter_s                                   8.77991 │
│ time_total_s                                      19.83405 │
│ training_iteration                                       2 │
│ accuracy                                            0.3262 │
│ loss                                               1.76772 │
╰────────────────────────────────────────────────────────────╯
(func pid=5592) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00003_3_batch_size=16,l1=32,l2=8,lr=0.0005_2026-01-16_22-30-49/checkpoint_000001)

Trial train_cifar_01e85_00004 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00004 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     4 │
│ device                                      cuda │
│ l1                                           256 │
│ l2                                             2 │
│ lr                                       0.00679 │
╰──────────────────────────────────────────────────╯
(func pid=5788) [1,  2000] loss: 2.187
(func pid=5788) [1,  4000] loss: 1.054
(func pid=5788) [1,  6000] loss: 0.700
(func pid=5788) [1,  8000] loss: 0.556

Trial status: 4 TERMINATED | 1 RUNNING | 5 PENDING
Current time: 2026-01-16 22:36:49. Total running time: 6min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00004   RUNNING       256      2   0.00678774               4                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00005   PENDING        32      2   0.00018331              16                                                    │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=5788) [2026-01-16 22:36:51,807 E 5788 5823] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=5788) [1, 10000] loss: 0.461

Trial train_cifar_01e85_00004 completed after 1 iterations at 2026-01-16 22:36:58. Total running time: 6min 9s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00004 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000000 │
│ time_this_iter_s                                  33.50241 │
│ time_total_s                                      33.50241 │
│ training_iteration                                       1 │
│ accuracy                                            0.0983 │
│ loss                                               2.30523 │
╰────────────────────────────────────────────────────────────╯
(func pid=5788) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00004_4_batch_size=4,l1=256,l2=2,lr=0.0068_2026-01-16_22-30-49/checkpoint_000000)

Trial train_cifar_01e85_00005 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00005 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                    16 │
│ device                                      cuda │
│ l1                                            32 │
│ l2                                             2 │
│ lr                                       0.00018 │
╰──────────────────────────────────────────────────╯
(func pid=5920) [1,  2000] loss: 2.330
(func pid=5920) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00005_5_batch_size=16,l1=32,l2=2,lr=0.0002_2026-01-16_22-30-49/checkpoint_000000)
(func pid=5920) [2,  2000] loss: 2.214

Trial status: 5 TERMINATED | 1 RUNNING | 4 PENDING
Current time: 2026-01-16 22:37:19. Total running time: 6min 30s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00005   RUNNING        32      2   0.00018331              16        1            10.7912   2.25213       0.152  │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00006   PENDING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00005 completed after 2 iterations at 2026-01-16 22:37:22. Total running time: 6min 33s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00005 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000001 │
│ time_this_iter_s                                   8.78223 │
│ time_total_s                                      19.57339 │
│ training_iteration                                       2 │
│ accuracy                                            0.2067 │
│ loss                                               2.16637 │
╰────────────────────────────────────────────────────────────╯
(func pid=5920) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00005_5_batch_size=16,l1=32,l2=2,lr=0.0002_2026-01-16_22-30-49/checkpoint_000001)

Trial train_cifar_01e85_00006 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00006 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     4 │
│ device                                      cuda │
│ l1                                           256 │
│ l2                                             8 │
│ lr                                       0.00712 │
╰──────────────────────────────────────────────────╯
(func pid=6116) [1,  2000] loss: 2.152
(func pid=6116) [1,  4000] loss: 1.031
(func pid=6116) [1,  6000] loss: 0.673

Trial status: 6 TERMINATED | 1 RUNNING | 3 PENDING
Current time: 2026-01-16 22:37:49. Total running time: 7min 0s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00006   RUNNING       256      8   0.00712426               4                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=6116) [1,  8000] loss: 0.509
(func pid=6116) [2026-01-16 22:37:53,814 E 6116 6151] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(func pid=6116) [1, 10000] loss: 0.410
(func pid=6116) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00006_6_batch_size=4,l1=256,l2=8,lr=0.0071_2026-01-16_22-30-49/checkpoint_000000)
(func pid=6116) [2,  2000] loss: 2.010
(func pid=6116) [2,  4000] loss: 1.003
(func pid=6116) [2,  6000] loss: 0.689
Trial status: 6 TERMINATED | 1 RUNNING | 3 PENDING
Current time: 2026-01-16 22:38:20. Total running time: 7min 31s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00006   RUNNING       256      8   0.00712426               4        1            33.4404   2.04758       0.2143 │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00007   PENDING         4      2   0.00163636              16                                                    │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=6116) [2,  8000] loss: 0.504
(func pid=6116) [2, 10000] loss: 0.409
(func pid=6116) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00006_6_batch_size=4,l1=256,l2=8,lr=0.0071_2026-01-16_22-30-49/checkpoint_000001)

Trial train_cifar_01e85_00006 completed after 2 iterations at 2026-01-16 22:38:31. Total running time: 7min 42s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00006 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000001 │
│ time_this_iter_s                                  31.65389 │
│ time_total_s                                      65.09426 │
│ training_iteration                                       2 │
│ accuracy                                            0.2196 │
│ loss                                               2.06751 │
╰────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00007 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00007 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                    16 │
│ device                                      cuda │
│ l1                                             4 │
│ l2                                             2 │
│ lr                                       0.00164 │
╰──────────────────────────────────────────────────╯
(func pid=6316) [1,  2000] loss: 2.317

Trial train_cifar_01e85_00007 completed after 1 iterations at 2026-01-16 22:38:46. Total running time: 7min 57s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00007 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000000 │
│ time_this_iter_s                                  10.90175 │
│ time_total_s                                      10.90175 │
│ training_iteration                                       1 │
│ accuracy                                            0.1004 │
│ loss                                               2.30289 │
╰────────────────────────────────────────────────────────────╯
(func pid=6316) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00007_7_batch_size=16,l1=4,l2=2,lr=0.0016_2026-01-16_22-30-49/checkpoint_000000)

Trial status: 8 TERMINATED | 2 PENDING
Current time: 2026-01-16 22:38:50. Total running time: 8min 1s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00006   TERMINATED    256      8   0.00712426               4        2            65.0943   2.06751       0.2196 │
│ train_cifar_01e85_00007   TERMINATED      4      2   0.00163636              16        1            10.9018   2.30289       0.1004 │
│ train_cifar_01e85_00008   PENDING       128      4   0.000264114              8                                                    │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00008 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00008 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     8 │
│ device                                      cuda │
│ l1                                           128 │
│ l2                                             4 │
│ lr                                       0.00026 │
╰──────────────────────────────────────────────────╯
(func pid=6446) [1,  2000] loss: 2.300
(func pid=6446) [1,  4000] loss: 1.054
(func pid=6446) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00008_8_batch_size=8,l1=128,l2=4,lr=0.0003_2026-01-16_22-30-49/checkpoint_000000)
(func pid=6446) [2,  2000] loss: 1.832
(func pid=6446) [2026-01-16 22:39:17,827 E 6446 6481] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14

Trial status: 8 TERMINATED | 1 RUNNING | 1 PENDING
Current time: 2026-01-16 22:39:20. Total running time: 8min 31s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00008   RUNNING       128      4   0.000264114              8        1            18.4361   1.89157       0.2756 │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00006   TERMINATED    256      8   0.00712426               4        2            65.0943   2.06751       0.2196 │
│ train_cifar_01e85_00007   TERMINATED      4      2   0.00163636              16        1            10.9018   2.30289       0.1004 │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=6446) [2,  4000] loss: 0.874
(func pid=6446) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00008_8_batch_size=8,l1=128,l2=4,lr=0.0003_2026-01-16_22-30-49/checkpoint_000001)
(func pid=6446) [3,  2000] loss: 1.661
(func pid=6446) [3,  4000] loss: 0.810
(func pid=6446) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00008_8_batch_size=8,l1=128,l2=4,lr=0.0003_2026-01-16_22-30-49/checkpoint_000002)
(func pid=6446) [4,  2000] loss: 1.551
Trial status: 8 TERMINATED | 1 RUNNING | 1 PENDING
Current time: 2026-01-16 22:39:50. Total running time: 9min 1s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00008   RUNNING       128      4   0.000264114              8        3            51.456    1.57389       0.3917 │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00006   TERMINATED    256      8   0.00712426               4        2            65.0943   2.06751       0.2196 │
│ train_cifar_01e85_00007   TERMINATED      4      2   0.00163636              16        1            10.9018   2.30289       0.1004 │
│ train_cifar_01e85_00009   PENDING         4      1   0.092961                 8                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(func pid=6446) [4,  4000] loss: 0.764

Trial train_cifar_01e85_00008 completed after 4 iterations at 2026-01-16 22:39:58. Total running time: 9min 9s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00008 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000003 │
│ time_this_iter_s                                  16.59541 │
│ time_total_s                                      68.05143 │
│ training_iteration                                       4 │
│ accuracy                                            0.4237 │
│ loss                                               1.51286 │
╰────────────────────────────────────────────────────────────╯
(func pid=6446) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00008_8_batch_size=8,l1=128,l2=4,lr=0.0003_2026-01-16_22-30-49/checkpoint_000003)

Trial train_cifar_01e85_00009 started with configuration:
╭──────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00009 config             │
├──────────────────────────────────────────────────┤
│ batch_size                                     8 │
│ device                                      cuda │
│ l1                                             4 │
│ l2                                             1 │
│ lr                                       0.09296 │
╰──────────────────────────────────────────────────╯
(func pid=6774) [1,  2000] loss: 2.329
(func pid=6774) [1,  4000] loss: 1.164

Trial status: 9 TERMINATED | 1 RUNNING
Current time: 2026-01-16 22:40:20. Total running time: 9min 31s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00009   RUNNING         4      1   0.092961                 8                                                    │
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00006   TERMINATED    256      8   0.00712426               4        2            65.0943   2.06751       0.2196 │
│ train_cifar_01e85_00007   TERMINATED      4      2   0.00163636              16        1            10.9018   2.30289       0.1004 │
│ train_cifar_01e85_00008   TERMINATED    128      4   0.000264114              8        4            68.0514   1.51286       0.4237 │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Trial train_cifar_01e85_00009 completed after 1 iterations at 2026-01-16 22:40:21. Total running time: 9min 32s
╭────────────────────────────────────────────────────────────╮
│ Trial train_cifar_01e85_00009 result                       │
├────────────────────────────────────────────────────────────┤
│ checkpoint_dir_name                      checkpoint_000000 │
│ time_this_iter_s                                  18.87604 │
│ time_total_s                                      18.87604 │
│ training_iteration                                       1 │
│ accuracy                                             0.101 │
│ loss                                               2.31448 │
╰────────────────────────────────────────────────────────────╯
2026-01-16 22:40:21,750 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48' in 0.0102s.

Trial status: 10 TERMINATED
Current time: 2026-01-16 22:40:21. Total running time: 9min 32s
Logical resource usage: 2.0/16 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:A10G)
Current best trial: 01e85_00000 with loss=1.1553119671106338 and params={'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                status         l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_cifar_01e85_00000   TERMINATED     16     16   0.000333728              8       10           166.714    1.15531       0.5875 │
│ train_cifar_01e85_00001   TERMINATED      1    256   0.00338356               2        1            64.513    2.30847       0.0988 │
│ train_cifar_01e85_00002   TERMINATED    256     32   0.0311106                2        1            63.6726   2.31367       0.1026 │
│ train_cifar_01e85_00003   TERMINATED     32      8   0.000513478             16        2            19.834    1.76772       0.3262 │
│ train_cifar_01e85_00004   TERMINATED    256      2   0.00678774               4        1            33.5024   2.30523       0.0983 │
│ train_cifar_01e85_00005   TERMINATED     32      2   0.00018331              16        2            19.5734   2.16637       0.2067 │
│ train_cifar_01e85_00006   TERMINATED    256      8   0.00712426               4        2            65.0943   2.06751       0.2196 │
│ train_cifar_01e85_00007   TERMINATED      4      2   0.00163636              16        1            10.9018   2.30289       0.1004 │
│ train_cifar_01e85_00008   TERMINATED    128      4   0.000264114              8        4            68.0514   1.51286       0.4237 │
│ train_cifar_01e85_00009   TERMINATED      4      1   0.092961                 8        1            18.876    2.31448       0.101  │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Best trial config: {'l1': 16, 'l2': 16, 'lr': 0.00033372805612201707, 'batch_size': 8, 'device': 'cuda'}
Best trial final validation loss: 1.1553119671106338
Best trial final validation accuracy: 0.5875
(func pid=6774) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2026-01-16_22-30-48/train_cifar_01e85_00009_9_batch_size=8,l1=4,l2=1,lr=0.0930_2026-01-16_22-30-49/checkpoint_000000)
Best trial test set accuracy: 0.5935
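
The iteration counts in the log above (trials ending after 1, 2, 4, or 10 iterations) are consistent with successive-halving rungs. The following is a minimal sketch of that idea, assuming a grace period of 1, a reduction factor of 2, and a cap of 10 iterations (these values are assumptions, not shown in this output). The function names `rung_milestones` and `should_stop` are hypothetical, and the stopping rule is a simplification for illustration, not the ASHAScheduler implementation in Ray Tune.

```python
def rung_milestones(grace_period, reduction_factor, max_t):
    """Iterations at which a trial is compared against its peers."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def should_stop(loss, losses_at_rung, reduction_factor):
    """Simplified rule: stop unless `loss` is within the best
    1/reduction_factor of all losses seen at this rung so far."""
    if not losses_at_rung:
        # The first trial to reach a rung has nothing to compare against.
        return False
    ranked = sorted(losses_at_rung + [loss])
    keep = max(1, len(ranked) // reduction_factor)
    cutoff = ranked[keep - 1]
    return loss > cutoff


# Assumed scheduler settings (hypothetical values for illustration).
milestones = rung_milestones(grace_period=1, reduction_factor=2, max_t=10)
print("Rungs:", milestones)

# A trial reaching the first rung with a loss near 2.30 is stopped when a
# peer has already recorded a loss of 1.89 at the same rung.
print("Stop weak trial:", should_stop(2.30, [1.89], reduction_factor=2))
print("Stop strong trial:", should_stop(1.50, [1.89, 2.30], reduction_factor=2))
```

Because the comparison only uses results that have already arrived at a rung, the decision does not wait for all trials to finish, which is what makes this style of early stopping asynchronous.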

Results#

Your Ray Tune trial summary output looks something like the following sample. The text table summarizes the validation performance of each trial, and the lines below it report the best hyperparameter configuration:

Number of trials: 10/10 (10 TERMINATED)
+-----+--------------+------+------+-------------+--------+---------+------------+
| ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
|-----+--------------+------+------+-------------+--------+---------+------------|
| ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
| ... |            4 |   64 |    8 | 0.0331514   |      1 | 2.31605 |     0.0983 |
| ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
| ... |           16 |   32 |   32 | 0.0128248   |     10 | 1.66912 |     0.4391 |
| ... |            4 |    8 |  128 | 0.00464561  |      2 | 1.7316  |     0.3463 |
| ... |            8 |  256 |    8 | 0.00031556  |      1 | 2.19409 |     0.1736 |
| ... |            4 |   16 |  256 | 0.00574329  |      2 | 1.85679 |     0.3368 |
| ... |            8 |    2 |    2 | 0.00325652  |      1 | 2.30272 |     0.0984 |
| ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |     0.292  |
| ... |            4 |   64 |   32 | 0.003734    |      8 | 1.53101 |     0.4761 |
+-----+--------------+------+------+-------------+--------+---------+------------+

Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
Best trial final validation loss: 1.5310075663924216
Best trial final validation accuracy: 0.4761
Best trial test set accuracy: 0.4737

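In this sample, the scheduler stopped eight of the ten trials after only one or two iterations to conserve resources. The best-performing trial reached a validation accuracy of about 47.6%, and its test set accuracy of 47.4% is consistent with that. Exact numbers vary between runs because each run samples different hyperparameter configurations; the run logged above, for example, reached 58.75% validation accuracy and 59.35% test accuracy.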
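
To check these observations against the sample table, the sketch below copies its rows into plain Python, selects the trial with the lowest validation loss, and counts how many trials ended after at most two iterations. It operates only on the printed numbers and does not call any Ray Tune API; the names `SAMPLE_TRIALS` and `FIELDS` are hypothetical.

```python
# Rows copied from the sample summary table:
# (batch_size, l1, l2, lr, iter, loss, accuracy)
SAMPLE_TRIALS = [
    (2, 1, 256, 0.000668163, 1, 2.31479, 0.0977),
    (4, 64, 8, 0.0331514, 1, 2.31605, 0.0983),
    (4, 2, 1, 0.000150295, 1, 2.30755, 0.1023),
    (16, 32, 32, 0.0128248, 10, 1.66912, 0.4391),
    (4, 8, 128, 0.00464561, 2, 1.7316, 0.3463),
    (8, 256, 8, 0.00031556, 1, 2.19409, 0.1736),
    (4, 16, 256, 0.00574329, 2, 1.85679, 0.3368),
    (8, 2, 2, 0.00325652, 1, 2.30272, 0.0984),
    (2, 2, 2, 0.000342987, 2, 1.76044, 0.292),
    (4, 64, 32, 0.003734, 8, 1.53101, 0.4761),
]
FIELDS = ("batch_size", "l1", "l2", "lr", "iter", "loss", "accuracy")
trials = [dict(zip(FIELDS, row)) for row in SAMPLE_TRIALS]

# The best trial is the one with the lowest validation loss.
best = min(trials, key=lambda t: t["loss"])

# Trials that ended after one or two iterations were stopped early.
stopped_early = [t for t in trials if t["iter"] <= 2]

print(
    f"Best config: l1={best['l1']}, l2={best['l2']}, "
    f"lr={best['lr']}, batch_size={best['batch_size']}"
)
print(f"Best validation accuracy: {best['accuracy']}")
print(f"Stopped after at most 2 iterations: {len(stopped_early)}/{len(trials)}")
```

Selecting by lowest loss matches the "Best trial config" line in the sample, which reports `l1=64`, `l2=32`, and `batch_size=4`.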

Observability#

Monitoring is critical when running large-scale experiments. Ray provides a dashboard that lets you view the status of your trials, check cluster resource use, and inspect logs in real time.

For debugging, Ray also offers distributed debugging tools that let you attach a debugger to running trials across the cluster.

Conclusion#

In this tutorial, you learned how to tune the hyperparameters of a PyTorch model using Ray Tune. You saw how to integrate Ray Tune into your PyTorch training loop, define a search space for your hyperparameters, and use a scheduler such as ASHAScheduler to terminate low-performing trials early. You also saw how to save checkpoints, report metrics to Ray Tune, run the hyperparameter search, and analyze the results.

Ray Tune makes it straightforward to scale your experiments from a single machine to a large cluster, helping you find the best model configuration efficiently.

Further reading#

Total running time of the script: (9 minutes 47.178 seconds)