TF Learning: DeepLabv3+ Code Reading 6 (train_utils)

Reading the DeepLabv3+ code: train_utils.py

1. get_model_learning_rate()

import tensorflow as tf


def get_model_learning_rate(
    learning_policy,  # Learning rate policy for training.
    base_learning_rate,  # The base learning rate for model training.
    learning_rate_decay_step,  # Decay the base learning rate at a fixed step.
    learning_rate_decay_factor,  # The rate to decay the base learning rate.
    training_number_of_steps,  # Number of steps for training.
    learning_power,  # Power used for 'poly' learning policy.
    slow_start_step,  # Training model with small learning rate for the
                      # first few steps.
    slow_start_learning_rate,  # The learning rate employed during slow start.
    slow_start_burnin_type='none'):  # The burnin type for the slow start
                                     # stage: `none` means no burnin; `linear`
                                     # means the learning rate increases
                                     # linearly from slow_start_learning_rate
                                     # and reaches base_learning_rate after
                                     # slow_start_step steps.
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policies.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  """
  global_step = tf.train.get_or_create_global_step()
  adjusted_global_step = global_step

  if slow_start_burnin_type != 'none':
    adjusted_global_step -= slow_start_step

  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    # Do linear burnin. Increase linearly from slow_start_learning_rate and
    # reach base_learning_rate after (global_step >= slow_start_step).
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  # Employ the small learning rate for the first few steps as a warm start.
  return tf.where(global_step < slow_start_step,
                  adjusted_slow_start_learning_rate, learning_rate)
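
To make the behavior concrete, here is a minimal sketch in plain Python (no TensorFlow). It re-implements the 'step' and 'poly' formulas from the docstring, plus the slow-start selection that the final tf.where performs. All flag values are illustrative assumptions, not values taken from this file.

# Plain-Python sketch of the schedules above; flag values are assumptions.

def step_lr(global_step, base_lr=0.0001, decay_step=2000, decay_factor=0.1):
  # staircase=True: the exponent is the integer count of completed decay
  # periods, so the learning rate drops in discrete jumps.
  return base_lr * decay_factor ** (global_step // decay_step)

def poly_lr(global_step, base_lr=0.0001, total_steps=30000, power=0.9):
  # Polynomial decay from base_lr down to an end learning rate of 0.
  return base_lr * (1 - global_step / total_steps) ** power

def lr_with_slow_start(global_step, slow_start_step=300,
                       slow_start_lr=1e-4, burnin='none'):
  # Emulates the final tf.where: before slow_start_step, use the (possibly
  # linearly increasing) slow-start rate; afterwards, the regular schedule.
  if global_step >= slow_start_step:
    return poly_lr(global_step)
  if burnin == 'linear':
    # Linear burnin toward the regular schedule's base_lr (0.0001 here).
    return slow_start_lr + (0.0001 - slow_start_lr) * global_step / slow_start_step
  return slow_start_lr

for step in (0, 150, 300, 15000, 30000):
  print(step, step_lr(step), poly_lr(step), lr_with_slow_start(step))

Note that poly_lr decays smoothly and reaches 0 exactly at total_steps, while step_lr is multiplied by decay_factor once per decay_step steps.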
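For context, train.py consumes the returned tensor roughly as in the sketch below. The flag values are hypothetical stand-ins for DeepLab's FLAGS, not the project's actual defaults; the point is that the schedule is a tensor tied to the global step, so it can be passed straight to an optimizer.

import tensorflow as tf
from deeplab.utils import train_utils

# Hypothetical flag values standing in for DeepLab's FLAGS.
learning_rate = train_utils.get_model_learning_rate(
    learning_policy='poly',
    base_learning_rate=0.0001,
    learning_rate_decay_step=2000,
    learning_rate_decay_factor=0.1,
    training_number_of_steps=30000,
    learning_power=0.9,
    slow_start_step=0,
    slow_start_learning_rate=1e-4)

# The schedule re-evaluates as the global step advances, so momentum SGD
# (which train.py uses) picks up the decayed rate automatically.
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)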