RematScope 类keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)
在Keras中启用重计算的上下文管理器。
重计算(梯度检查点)通过在反向传播期间重新计算中间激活值来以计算换取内存。这对于在内存限制下训练大型模型或大批量大小特别有用。
这应该在初始化层时使用(例如,layer(input))。重计算在执行时应用,而不是在创建时应用。
参数
"full":全局应用于所有支持的操作。"activations":应用于包含keras.activations(例如,Dense(..., activation=relu))的任何层的激活。"larger_than":应用于输出大小大于output_size_threshold的层。"list_of_layers":应用于指定层名称列表。None:禁用重计算。"larger_than"模式的输出大小阈值。输出大于此阈值的层将被重计算。默认为1024。"list_of_layers"模式的层名称列表。默认为空列表。示例
使用“list_of_layers”模式
from keras import RematScope
input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
layer1 = keras.layers.Dense(128, name="dense_1")
layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
layer3 = keras.layers.Dense(64, name="dense_2")
# Only layer1 and layer2 will apply rematerialization
output1 = layer1(input_tensor)
output2 = layer2(output1)
output3 = layer3(output2)
使用具有特定输出大小阈值的“larger_than”模式
with RematScope(mode="larger_than", output_size_threshold=2048):
layer = keras.layers.Conv2D(64, (3, 3))
output = layer(input_tensor) # Conv2D outputs larger than 2048
用于精细控制的嵌套作用域
with RematScope(mode="full"):
# Create layers
layer1 = keras.layers.Dense(128, activation='relu')
output1 = layer1(input_tensor) # layer1 is fully rematerialized
with RematScope(mode="larger_than", output_size_threshold=512):
layer2 = keras.layers.Conv2D(32, (3, 3))
output2 = layer2(output1) # layer2 is conditionally rematerialized
# if output > 512