RematScope
类keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)
Keras 中用于启用重计算的上下文管理器。
重计算(梯度检查点)通过在反向传播过程中重新计算中间激活来牺牲计算换取内存。这对于在有限内存约束下训练大型模型或大批量数据特别有用。
这应该在初始化层时使用(例如,layer(input)
)。重计算在执行时应用,而不是创建时应用。
参数
"full"
:全局应用于所有支持的操作。"activations"
:应用于包含 keras.activations
的任何层(例如 Dense(..., activation=relu)
)中的激活函数。"larger_than"
:应用于输出大小大于 output_size_threshold
的层。"list_of_layers"
:应用于指定的层名称列表。None
:禁用重计算。"larger_than"
模式的输出大小阈值。输出大于此阈值的层将进行重计算。默认值为 1024
。"list_of_layers"
模式的层名称列表。默认值为空列表。示例
使用 "list_of_layers" 模式
from keras import RematScope
input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
layer1 = keras.layers.Dense(128, name="dense_1")
layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
layer3 = keras.layers.Dense(64, name="dense_2")
# Only layer1 and layer2 will apply rematerialization
output1 = layer1(input_tensor)
output2 = layer2(output1)
output3 = layer3(output2)
使用带有特定输出大小阈值的 "larger_than" 模式
with RematScope(mode="larger_than", output_size_threshold=2048):
layer = keras.layers.Conv2D(64, (3, 3))
output = layer(input_tensor) # Conv2D outputs larger than 2048
嵌套作用域实现精细控制
with RematScope(mode="full"):
# Create layers
layer1 = keras.layers.Dense(128, activation='relu')
output1 = layer1(input_tensor) # layer1 is fully rematerialized
with RematScope(mode="larger_than", output_size_threshold=512):
layer2 = keras.layers.Conv2D(32, (3, 3))
output2 = layer2(output1) # layer2 is conditionally rematerialized
# if output > 512