Remat

[源代码]

remat 函数

keras.remat(f)

应用重计算来优化函数或层的内存。

重计算是一种以计算换内存的优化技术。在反向传播时,不会存储中间结果(例如激活值),而是会重新计算它们。这降低了峰值内存使用量,但会增加计算时间,从而允许在相同的内存限制下训练更大的模型或使用更大的批次大小。

参数

  • f:一个可调用的函数,将应用重计算。这通常是一个计算密集型操作,其中中间状态可以被重新计算而不是被存储。

返回

一个已包装的函数,它应用重计算。返回的函数定义了一个自定义梯度,确保在反向传播期间,根据需要重新计算前向计算。

示例

from keras import Model
class CustomRematLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.remat_function = remat(self.intermediate_function)

    def intermediate_function(self, x):
        for _ in range(2):
            x = x + x * 0.1  # Simple scaled transformation
        return x

    def call(self, inputs):
        return self.remat_function(inputs)

# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x)  # Custom layer with rematerialization
outputs = layers.Dense(1)(x)

# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")