TorchModuleWrapper
类keras.layers.TorchModuleWrapper(module, name=None, **kwargs)
Torch 模块包装层。
TorchModuleWrapper
是一个包装类,可以将任何 torch.nn.Module
转换为 Keras 层,特别是通过使其参数可被 Keras 追踪。
TorchModuleWrapper
仅与 PyTorch 后端兼容,不能与 TensorFlow 或 JAX 后端一起使用。
参数
torch.nn.Module
实例。如果它是一个 LazyModule
实例,则必须在将实例传递给 TorchModuleWrapper
之前初始化其参数(例如,通过调用它一次)。示例
以下是如何将 TorchModuleWrapper
与普通 PyTorch 模块一起使用的示例。
import torch
import torch.nn as nn
import torch.nn.functional as F
import keras
from keras.layers import TorchModuleWrapper
class Classifier(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Wrap `torch.nn.Module`s with `TorchModuleWrapper`
# if they contain parameters
self.conv1 = TorchModuleWrapper(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
)
self.pool = nn.MaxPool2d(kernel_size=(2, 2))
self.flatten = nn.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.fc = TorchModuleWrapper(nn.Linear(1600, 10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("# Output shape", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)