Keras 3 API 文档 / 重计算 / RematScope

RematScope

[源代码]

RematScope

keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)

Keras 中用于启用重计算的上下文管理器。

重计算(梯度检查点)通过在反向传播过程中重新计算中间激活来牺牲计算换取内存。这对于在有限内存约束下训练大型模型或大批量数据特别有用。

这应该在初始化层时使用(例如,layer(input))。重计算在执行时应用,而不是创建时应用。

参数

  • mode:要应用的重计算模式。可选值有
    • "full":全局应用于所有支持的操作。
    • "activations":应用于包含 keras.activations 的任何层(例如 Dense(..., activation=relu))中的激活函数。
    • "larger_than":应用于输出大小大于 output_size_threshold 的层。
    • "list_of_layers":应用于指定的层名称列表。
    • None:禁用重计算。
  • output_size_threshold"larger_than" 模式的输出大小阈值。输出大于此阈值的层将进行重计算。默认值为 1024
  • layer_names"list_of_layers" 模式的层名称列表。默认值为空列表。

示例

使用 "list_of_layers" 模式

from keras import RematScope
input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
    layer1 = keras.layers.Dense(128, name="dense_1")
    layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
    layer3 = keras.layers.Dense(64, name="dense_2")
    # Only layer1 and layer2 will apply rematerialization
    output1 = layer1(input_tensor)
    output2 = layer2(output1)
    output3 = layer3(output2)

使用带有特定输出大小阈值的 "larger_than" 模式

with RematScope(mode="larger_than", output_size_threshold=2048):
    layer = keras.layers.Conv2D(64, (3, 3))
    output = layer(input_tensor)  # Conv2D outputs larger than 2048

嵌套作用域实现精细控制

with RematScope(mode="full"):
    # Create layers
    layer1 = keras.layers.Dense(128, activation='relu')
    output1 = layer1(input_tensor)  # layer1 is fully rematerialized
    with RematScope(mode="larger_than", output_size_threshold=512):
        layer2 = keras.layers.Conv2D(32, (3, 3))
        output2 = layer2(output1) # layer2 is conditionally rematerialized
        # if output > 512