Author: Humbulani Ndou
Date created: 2025/03/17
Last modified: 2025/03/17
Description: Using Gated Residual and Variable Selection Networks for prediction, with hyperparameter tuning.
The following example extends the script `structured_data/classification_with_grn_and_vsn.py` by incorporating hyperparameter tuning with Autokeras and KerasTuner. The specific APIs used from these two packages are described in detail in the relevant code sections.
This example demonstrates the use of Gated Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by Bryan Lim et al. in "Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting", for structured data classification. GRNs give the model the flexibility to apply non-linear processing only where needed. VSNs allow the model to softly remove any unnecessary noisy inputs which could negatively impact performance. Together, these techniques help improve the learning capacity of deep neural network models.
Note that this example implements only the GRN and VSN components described in the paper, not the whole TFT model, as GRN and VSN are useful on their own for structured data learning tasks.
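For reference, here is a sketch of the GRN equations from the TFT paper, omitting the optional context vector that this example does not use; the `GatedResidualNetwork` layer defined later follows this pattern, with the activation and the use of layer normalization exposed as tunable hyperparameters:

$$\eta_2 = \mathrm{ELU}(W_2 a + b_2)$$

$$\eta_1 = W_1 \eta_2 + b_1$$

$$\mathrm{GLU}(\eta_1) = \sigma(W_4 \eta_1 + b_4) \odot (W_5 \eta_1 + b_5)$$

$$\mathrm{GRN}(a) = \mathrm{LayerNorm}(a + \mathrm{GLU}(\eta_1))$$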
To run the code you need to use TensorFlow 2.3 or higher.
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).
Here's the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
import os
import subprocess
import tarfile
import numpy as np
import pandas as pd
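# `tree` (the dm-tree package) is used below in Graph.build to flatten nested block outputs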
import tree
from typing import Optional, Union
os.environ["KERAS_BACKEND"] = "tensorflow" # or jax, or torch
# Keras imports
import keras
from keras import layers
# KerasTuner imports
import keras_tuner
from keras_tuner import HyperParameters
# AutoKeras imports
import autokeras as ak
from autokeras.utils import utils, types
Let's download the data and load it into a Pandas dataframe:
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):
dataframe.shape
(303, 14)
Here's a preview of a few samples:
dataframe.head()
 | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
The last column, "target", indicates whether the patient has a heart disease (1) or not (0).
Let's split the data into a training and validation set:
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(
f"Using {len(train_dataframe)} samples for training "
f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation
Here, we define the metadata of the dataset that will be useful for reading and parsing the data into input features, and for encoding the input features with respect to their types.
COLUMN_NAMES = [
"age",
"sex",
"cp",
"trestbps",
"chol",
"fbs",
"restecg",
"thalach",
"exang",
"oldpeak",
"slope",
"ca",
"thal",
"target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Integer-encoded categorical values are kept as ints and string-encoded ones
# as strs, so that the lookup layers below receive values of the right type.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
feature_name: sorted(
[
# Integer categorical values must be int and string values must be str
value if dataframe[feature_name].dtype == "int64" else str(value)
for value in list(dataframe[feature_name].unique())
]
)
for feature_name in COLUMN_NAMES
if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All feature names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
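As a quick sanity check (a hypothetical snippet, not part of the original script), we can inspect the vocabulary built for the string-encoded feature `thal`:
# The entries are the sorted unique values found in the dataframe.
print(CATEGORICAL_FEATURES_WITH_VOCABULARY["thal"])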
The following features are categorical features encoded as integers:
sex
cp
fbs
restecg
exang
ca
We will encode these features using one-hot encoding. We have two options here:

- Use `CategoryEncoding()`, which requires knowing the range of input values and will error on input outside the range.
- Use `IntegerLookup()`, which will build a lookup table for inputs and reserve an output index for unknown input values.

For this example, we want a simple solution that will handle out of range inputs at inference, so we will use `IntegerLookup()`.
We also have a categorical feature encoded as a string: `thal`. We will create an index of all possible features and encode output using the `StringLookup()` layer.
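As a minimal sketch of this encoding (illustrative values, not part of the original script), a lookup layer with `output_mode="binary"` turns a category value into a one-hot vector over its vocabulary:
# One-hot encode the value 1 against a three-value vocabulary.
lookup = layers.IntegerLookup(
    vocabulary=[0, 1, 2], mask_token=None, num_oov_indices=0, output_mode="binary"
)
print(lookup([1]))  # a length-3 binary vector with a 1 at the index of value 1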
Finally, the following features are continuous numerical features:
age
trestbps
chol
thalach
oldpeak
slope
For each of these features, we will use a `Normalization()` layer to make sure the mean of each feature is 0 and its standard deviation is 1.
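A minimal sketch of what `Normalization()` does (illustrative values, not part of the original script):
# Learn the mean and variance from a tiny batch, then standardize a value.
normalizer = layers.Normalization()
normalizer.adapt(np.array([[120.0], [140.0], [160.0]]))
print(normalizer(np.array([[140.0]])))  # ~0.0, since 140 is the adapted mean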
Below, we define a utility function that performs these operations:
# Tensorflow required for tf.data.Dataset
import tensorflow as tf
# We pre-process the categorical elements of our dataset here, converting them to
# one-hot encoded vectors, so that this step is not repeated during model training,
# since only the TensorFlow backend supports strings.
def encode_categorical(features, target):
for f in features:
if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
# Create a lookup to convert string values to integer indices.
# Since we are not using a mask token nor expecting any out of vocabulary
# (oov) token, we set mask_token to None and num_oov_indices to 0.
cls = (
layers.StringLookup
if features[f].dtype == "string"
else layers.IntegerLookup
)
features[f] = cls(
vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
mask_token=None,
num_oov_indices=0,
output_mode="binary",
)(features[f])
# Change features from OrderedDict to Dict to match Inputs as they are Dict.
return dict(features), target
Let's generate `tf.data.Dataset` objects for each dataframe:
def dataframe_to_dataset(dataframe):
dataframe = dataframe.copy()
labels = dataframe.pop("target")
ds = (
tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
.map(encode_categorical)
.shuffle(buffer_size=len(dataframe))
)
return ds
train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features and `target` is the value 0 or 1:
for x, y in train_ds.take(1):
print("Input:", x)
print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=37>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=120>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=215>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 0, 0])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=170>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 0, 0])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>}
Target: tf.Tensor(0, shape=(), dtype=int64)
Let's batch the datasets:
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
Here, we subclass the Autokeras `Graph`, overriding its `build` method to construct the Keras `Model` and its `_compile_keras_model` method to expose the optimizer and learning rate as hyperparameters:
class Graph(ak.graph.Graph):
def build(self, hp):
"""Build the HyperModel into a Keras Model."""
keras_nodes = {}
keras_input_nodes = []
for node in self.inputs:
node_id = self._node_to_id[node]
input_node = node.build_node(hp)
output_node = node.build(hp, input_node)
keras_input_nodes.append(input_node)
keras_nodes[node_id] = output_node
for block in self.blocks:
temp_inputs = (
{
n.name: keras_nodes[self._node_to_id[n]]
for n in block.inputs
if isinstance(n, ak.Input)
}
if isinstance(block.inputs[0], ak.Input)
else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
)
outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
for n, o in zip(block.outputs, outputs):
keras_nodes[self._node_to_id[n]] = o
model = keras.models.Model(
keras_input_nodes,
[
keras_nodes[self._node_to_id[output_node]]
for output_node in self.outputs
],
)
return self._compile_keras_model(hp, model)
def _compile_keras_model(self, hp, model):
# Specify hyperparameters from compile(...)
optimizer_name = hp.Choice(
"optimizer",
["adam", "sgd"],
default="adam",
)
learning_rate = hp.Choice(
"learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
)
if optimizer_name == "adam":
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
elif optimizer_name == "sgd":
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
model.compile(
optimizer=optimizer,
metrics=self._get_metrics(),
loss=self._get_loss(),
)
return model
Here, we subclass the Autokeras `Input` node object and override its `dtype` attribute from `None` to a user-supplied value. We also override the `build_node` method to use the user-supplied name for the `Input` layer.
class Input(ak.Input):
def __init__(self, dtype, name=None, **kwargs):
super().__init__(name=name, **kwargs)
# Override dtype to a user dtype value
self.dtype = dtype
self.name = name
def build_node(self, hp):
return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)
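A tiny illustration of the overrides (hypothetical, not part of the original script), showing that the node now carries the user-supplied dtype and name:
node = Input(dtype="float32", name="age", shape=(1,))
print(node.dtype, node.name)  # float32 age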
Here, we subclass the Autokeras `ClassificationHead`, overriding its `__init__` method and adding a `get_expected_shape` method to infer the shape of the labels. We remove the preprocessing functionality, since we prefer to perform such operations manually.
class ClassifierHead(ak.ClassificationHead):
def __init__(
self,
num_classes: Optional[int] = None,
multi_label: bool = False,
loss: Optional[types.LossType] = None,
metrics: Optional[types.MetricsType] = None,
dropout: Optional[float] = None,
**kwargs,
):
self.num_classes = num_classes
self.multi_label = multi_label
self.dropout = dropout
if metrics is None:
metrics = ["accuracy"]
if loss is None:
loss = self.infer_loss()
ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
self.shape = self.get_expected_shape()
def get_expected_shape(self):
# Compute expected shape from num_classes.
if self.num_classes == 2 and not self.multi_label:
return [1]
return [self.num_classes]
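For instance (a hypothetical check, not part of the original script), instantiating the head for our binary task infers a single-unit output shape:
head = ClassifierHead(num_classes=2, multi_label=False)
print(head.shape)  # [1]: a single sigmoid unit, with binary crossentropy as the inferred loss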
This is a Keras layer defined in the script `structured_data/classification_with_grn_and_vsn.py`. Refer to that script for more details about this layer:
class GatedLinearUnit(layers.Layer):
def __init__(self, num_units, activation, **kwargs):
super().__init__(**kwargs)
self.linear = layers.Dense(num_units)
self.sigmoid = layers.Dense(num_units, activation=activation)
def call(self, inputs):
return self.linear(inputs) * self.sigmoid(inputs)
def build(self):
self.built = True
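A quick smoke test of the layer (a hypothetical snippet using random data):
glu = GatedLinearUnit(num_units=8, activation="sigmoid")
print(glu(np.random.rand(2, 4).astype("float32")).shape)  # (2, 8)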
This is another Keras layer defined in the script `structured_data/classification_with_grn_and_vsn.py`. Refer to that script for more details about this layer:
class GatedResidualNetwork(layers.Layer):
def __init__(
self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
):
super().__init__(**kwargs)
self.num_units = num_units
self.use_layernorm = use_layernorm
self.elu_dense = layers.Dense(num_units, activation=activation)
self.linear_dense = layers.Dense(num_units)
self.dropout = layers.Dropout(dropout_rate)
self.gated_linear_unit = GatedLinearUnit(num_units, activation)
self.layer_norm = layers.LayerNormalization()
self.project = layers.Dense(num_units)
def call(self, inputs, hp):
x = self.elu_dense(inputs)
x = self.linear_dense(x)
x = self.dropout(x)
if inputs.shape[-1] != self.num_units:
inputs = self.project(inputs)
x = inputs + self.gated_linear_unit(x)
use_layernorm = self.use_layernorm
if use_layernorm is None:
use_layernorm = hp.Boolean("use_layernorm", default=True)
if use_layernorm:
x = self.layer_norm(x)
return x
def build(self):
self.built = True
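Similarly, a quick smoke test (hypothetical; note that `call` expects a KerasTuner `HyperParameters` object so it can sample `use_layernorm` when that argument is left unset):
hp = HyperParameters()
grn = GatedResidualNetwork(num_units=8, dropout_rate=0.1, activation="elu")
print(grn(np.random.rand(2, 4).astype("float32"), hp=hp).shape)  # (2, 8)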
We convert the following Keras layer into an Autokeras `Block` so that its hyperparameters become tunable. For details on writing a custom `Block`, refer to the Autokeras blocks API:
class VariableSelection(ak.Block):
def __init__(
self,
num_units: Optional[Union[int, HyperParameters.Choice]] = None,
dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
activation: Optional[Union[str, HyperParameters.Choice]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.dropout = utils.get_hyperparameter(
dropout_rate,
HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
float,
)
self.num_units = utils.get_hyperparameter(
num_units,
HyperParameters().Choice(
"num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
),
int,
)
self.activation = utils.get_hyperparameter(
activation,
HyperParameters().Choice(
"vsn_activation", ["sigmoid", "elu"], default="sigmoid"
),
str,
)
def build(self, hp, inputs):
num_units = utils.add_to_hp(self.num_units, hp, "num_units")
dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
activation = utils.add_to_hp(self.activation, hp, "activation")
concat_inputs = []
# Project the features to 'num_units' dimension
for input_ in inputs:
if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
concat_inputs.append(
keras.layers.Dense(units=num_units)(inputs[input_])
)
else:
# Create a Normalization layer for our feature
normalizer = layers.Normalization()
# Prepare a Dataset that only yields our feature
feature_ds = train_ds.map(lambda x, y: x[input_]).map(
lambda x: keras.ops.expand_dims(x, -1)
)
# Learn the statistics of the data
normalizer.adapt(feature_ds)
# Normalize the input feature
normal_feature = normalizer(inputs[input_])
concat_inputs.append(
keras.layers.Dense(units=num_units)(normal_feature)
)
v = layers.concatenate(concat_inputs)
v = GatedResidualNetwork(
num_units=num_units, dropout_rate=dropout_rate, activation=activation
)(v, hp=hp)
v = keras.ops.expand_dims(
layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
)
x = []
x += [
GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
for i in concat_inputs
]
x = keras.ops.stack(x, axis=1)
return keras.ops.squeeze(
keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
)
# Categorical features have different shapes after encoding, depending on the
# vocabulary or unique values of each feature. We create the inputs accordingly,
# to match the data elements generated by tf.data.Dataset after pre-processing
def create_model_inputs():
inputs = {
f: (
Input(
name=f,
shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
dtype="int64",
)
if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
else Input(name=f, shape=(1,), dtype="float32")
)
for f in FEATURE_NAMES
}
return inputs
Here, we use the Autokeras `Functional` API to build a network of `Block`s, which will then be built into a KerasTuner `HyperModel` and a Keras `Model`.
class MyHyperModel(keras_tuner.HyperModel):
def build(self, hp):
inputs = create_model_inputs()
features = VariableSelection()(inputs)
outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
model = Graph(inputs=inputs, outputs=outputs)
model = model.build(hp)
return model
def fit(self, hp, model, *args, **kwargs):
return model.fit(
*args,
# Tune whether to shuffle the data in each epoch.
shuffle=hp.Boolean("shuffle"),
**kwargs,
)
We use the `RandomSearch` tuner to search the space for the best hyperparameters. We also display the search space summary.
print("Start training and searching for the best model...")
tuner = keras_tuner.RandomSearch(
MyHyperModel(),
objective="val_accuracy",
max_trials=3,
overwrite=True,
directory="my_dir",
project_name="tune_hypermodel",
)
# Show the search space summary
print("Tuner search space summary:\n")
tuner.search_space_summary()
# Search for best model
tuner.search(train_ds, epochs=2, validation_data=val_ds)
Trial 3 Complete [00h 00m 16s]
val_accuracy: 0.8032786846160889
Best val_accuracy So Far: 0.8032786846160889
Total elapsed time: 00h 00m 34s
# Get the top model.
models = tuner.get_best_models(num_models=1)
best_model = models[0]
best_model.summary()
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 346 variables.
saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional"
(The full `model.summary()` layer table is abridged here. It lists one `InputLayer` per feature; `CastToFloat32` and adapted `Normalization` layers for the numerical features; a per-feature `Dense` projection to 16 units; a `Concatenate` over all projected features; the `GatedResidualNetwork` layers of the variable selection network; the `Dense`, `ExpandDims`, `Transpose`, `Stack`, `Matmul` and `Squeeze` operations that compute and apply the selection weights; and a final `Dropout`, `Dense` and sigmoid `Activation` classifier head.)
Total params: 23,024 (89.96 KB)
Trainable params: 23,006 (89.87 KB)
Non-trainable params: 18 (96.00 B)
To get a prediction for a new sample, you can simply call `model.predict()`. There are just two things you need to do:

1. Wrap scalars into a list so as to have a batch dimension (models only process batches of data, not single samples).
2. Call `convert_to_tensor` on each feature.
sample = {
"age": 60,
"sex": 1,
"cp": 1,
"trestbps": 145,
"chol": 233,
"fbs": 1,
"restecg": 2,
"thalach": 150,
"exang": 0,
"oldpeak": 2.3,
"slope": 3,
"ca": 0,
"thal": "fixed",
}
# Given a category's vocabulary (`cat`, the key in the sample above) and a category
# value (`cat_value`, the value in the sample above), return its one-hot encoding
def get_cat_encoding(cat, cat_value):
# Create a list of zeros with the same length as categories
encoding = [0] * len(cat)
# Find the index of category_value in categories and set the corresponding position to 1
if cat_value in cat:
encoding[cat.index(cat_value)] = 1
return encoding
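For example (a hypothetical call, not part of the original script), encoding the sample's `thal` value against its vocabulary:
print(get_cat_encoding(CATEGORICAL_FEATURES_WITH_VOCABULARY["thal"], sample["thal"]))
# a list of zeros with a single 1 at the index of "fixed" in the vocabulary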
for name, value in sample.items():
if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
sample.update(
{
name: get_cat_encoding(
CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
)
}
)
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = best_model.predict(input_dict)
print(
f"This particular patient had a {100 * predictions[0][0]:.1f} "
"percent probability of having a heart disease, "
"as evaluated by our model."
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 162ms/step
This particular patient had a 28.1 percent probability of having a heart disease, as evaluated by our model.