Keras 3 API 文档 / KerasTuner / 预言机 / Oracle 基类

Oracle 基类

[源代码]

Oracle

keras_tuner.Oracle(
    objective=None,
    max_trials=None,
    hyperparameters=None,
    allow_new_entries=True,
    tune_new_entries=True,
    seed=None,
    max_retries_per_trial=0,
    max_consecutive_failed_trials=3,
)

实现超参数优化算法。

在并行调优设置中,只有一个 Oracle 实例。工作进程将通过 gPRC 调用 Oracle 方法与集中式 Oracle 实例进行通信。

Trial 对象通常用作通过 gPRC 调用传递信息的通信数据包,以在工作进程 Tuner 实例和 Oracle 之间传递信息。例如,Oracle.create_trial() 返回一个 Trial 对象,而 Oracle.end_trial() 在其参数中接受一个 Trial

当同一个 Trial 实例经过 gRPC 调用时,会重建其新副本。当通过调用 Oracle.end_trial() 将工作进程 Tuner 中的 Trial 对象传递回 Oracle 时,这些更改会与 Oracle 中的原始副本同步。

参数

  • objective: 字符串,keras_tuner.Objective 实例或 keras_tuner.Objective 和字符串的列表。如果为字符串,则优化方向(最小化或最大化)将被推断。如果为 keras_tuner.Objective 的列表,我们将最小化所有目标的总和以最小化,减去所有目标的总和以最大化。当 Tuner.run_trial()HyperModel.fit() 返回单个浮点数作为要最小化的目标时,objective 参数是可选的。
  • max_trials: 整数,最多测试的试验(模型配置)总数。请注意,如果搜索空间已用尽,则预言机可能会在测试 max_trial 个模型之前中断搜索。
  • hyperparameters: 可选的 HyperParameters 实例。可用于覆盖(或预先注册)搜索空间中的超参数。
  • tune_new_entries: 布尔值,是否应将超模型请求但未在 hyperparameters 中指定的超参数条目添加到搜索空间中。如果不是,则将使用这些参数的默认值。默认为 True。
  • allow_new_entries: 布尔值,是否允许超模型请求 hyperparameters 中未列出的超参数条目。默认为 True。
  • seed: 整数。随机种子。
  • max_retries_per_trial: 整数。默认为 0。如果试验崩溃或结果无效,则重试 Trial 的最大次数。
  • max_consecutive_failed_trials: 整数。默认为 3。连续失败的 Trial 的最大数量。达到此数量时,搜索将停止。当所有重试均未成功时,Trial 将标记为失败。

[源代码]

wrapped_func 函数

keras_tuner.Oracle.create_trial()

[源代码]

wrapped_func 函数

keras_tuner.Oracle.end_trial()

[源代码]

get_best_trials 方法

Oracle.get_best_trials(num_trials=1)

返回最佳 Trial


[源代码]

get_state 方法

Oracle.get_state()

返回此对象的当前状态。

此方法在 save 期间调用。

返回值

一个字典,其中包含可序列化对象作为状态。


[源代码]

set_state 方法

Oracle.set_state(state)

设置此对象的当前状态。

此方法在 reload 期间调用。

参数

  • state: 一个字典,其中包含序列化对象作为要恢复的状态。

[源代码]

score_trial 方法

Oracle.score_trial(trial)

对已完成的 Trial 进行评分。

此方法可以在子类中重写,以提供一组超参数值的得分。此方法在已完成的 Trial 上从 end_trial 调用。

参数

  • trial: 已完成的 Trial 对象。

[源代码]

populate_space 方法

Oracle.populate_space(trial_id)

使用试验值填充超参数空间。

此方法应在子类中重写,并在 create_trial 中调用,以使用值填充超参数空间。

参数

  • trial_id: 字符串,此试验的 ID。

返回值

一个字典,其中包含键“values”和“status”,其中“values”是参数名称到建议值的映射,而“status”应为“RUNNING”(试验可以正常开始)、“IDLE”(预言机正在等待某些内容,无法创建试验)或“STOPPED”(预言机已完成搜索,不应创建新试验)。


[源代码]

wrapped_func 函数

keras_tuner.Oracle.update_trial()