代码示例 / 结构化数据 / 使用门控残差网络和变量选择网络进行分类及超参数调优

使用门控残差网络和变量选择网络进行分类及超参数调优

作者: Humbulani Ndou
创建日期 2025/03/17
最后修改 2025/03/17
描述: 使用门控残差网络和变量选择网络进行预测并进行超参数调优。

ⓘ 本示例使用 Keras 2

在 Colab 中查看 GitHub 源码


引言

以下示例通过使用 AutokerasKerasTuner 集成超参数调优,扩展了脚本 `structured_data/classification_with_grn_and_vsn.py`。将在相关代码部分详细描述这两个包中使用的具体 API。

本示例展示了门控残差网络 (GRN) 和变量选择网络 (VSN) 在结构化数据分类中的应用,这些网络由 Bryan Lim 等人在《用于可解释多视界时间序列预测的 Temporal Fusion Transformers (TFT)》一文中提出。GRN 赋予模型灵活性,仅在需要的地方应用非线性处理。VSN 允许模型柔性地移除任何可能对性能产生负面影响的不必要的噪声输入。这些技术共同有助于提高深度神经网络模型的学习能力。

请注意,本示例仅实现了文中描述的 GRN 和 VSN 组件,而非整个 TFT 模型,因为 GRN 和 VSN 自身对于结构化数据学习任务就很有用。

要运行代码,你需要使用 TensorFlow 2.3 或更高版本。


数据集

我们的数据集由克利夫兰诊所基金会提供,用于心脏病研究。它是一个包含 303 行的 CSV 文件。每行包含一个患者(一个样本)的信息,每列描述患者的一个属性(一个特征)。我们使用这些特征来预测患者是否患有心脏病(二分类)。

以下是每个特征的描述

描述 特征类型
年龄 年龄(岁) 数值型
性别 (1 = 男性;0 = 女性) 类别型
CP 胸痛类型 (0, 1, 2, 3, 4) 类别型
Trestbpd 静息血压(入院时,单位:mmHg) 数值型
Chol 血清胆固醇(单位:mg/dl) 数值型
FBS 空腹血糖 > 120 mg/dl(1 = 是;0 = 否) 类别型
RestECG 静息心电图结果 (0, 1, 2) 类别型
Thalach 达到的最大心率 数值型
Exang 运动诱发心绞痛(1 = 是;0 = 否) 类别型
Oldpeak 运动相对于静息引起的 ST 段压低 数值型
Slope 运动峰值 ST 段的斜率 数值型
CA 荧光透视着色的主要血管数量 (0-3) 数值型和类别型
Thal 3 = 正常;6 = 固定缺损;7 = 可逆缺损 类别型
目标 心脏病诊断(1 = 是;0 = 否) 目标

设置

import os
import subprocess
import tarfile
import numpy as np
import pandas as pd
import tree
from typing import Optional, Union

os.environ["KERAS_BACKEND"] = "tensorflow"  # or jax, or torch

# Keras imports
import keras
from keras import layers

# KerasTuner imports
import keras_tuner
from keras_tuner import HyperParameters

# AutoKeras imports
import autokeras as ak
from autokeras.utils import utils, types

准备数据

让我们下载数据并将其加载到 Pandas 数据框中

file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)

数据集包含 303 个样本,每个样本有 14 列(13 个特征,加上目标标签)

dataframe.shape
(303, 14)

以下是部分样本的预览

dataframe.head()
年龄 性别 cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal 目标
0 63 1 1 145 233 1 2 150 0 2.3 3 0 固定缺损 0
1 67 1 4 160 286 0 2 108 1 1.5 2 3 正常 1
2 67 1 4 120 229 0 2 129 1 2.6 2 2 可逆缺损 0
3 37 1 3 130 250 0 0 187 0 3.5 3 0 正常 0
4 41 0 2 130 204 0 2 172 0 1.4 1 0 正常 0

最后一列“target”表示患者是否患有心脏病(1)或未患病(0)。

让我们将数据分割成训练集和验证集

val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation

定义数据集元数据

在这里,我们定义数据集的元数据,这对于读取和解析数据以形成输入特征,以及根据特征类型对输入特征进行编码非常有用。

COLUMN_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
    "target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Note that we add 'v=' as a prefix to all categorical feature values to make
# sure that they are treated as strings.

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted(
        [
            # Integer categorcal must be int and string must be str
            value if dataframe[feature_name].dtype == "int64" else str(value)
            for value in list(dataframe[feature_name].unique())
        ]
    )
    for feature_name in COLUMN_NAMES
    if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All features names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)

使用 Keras 层进行特征预处理

以下特征是编码为整数的类别型特征

  • 性别
  • cp
  • fbs
  • restecg
  • exang
  • ca

我们将使用独热编码对这些特征进行编码。这里有两种选择

  • 使用 `CategoryEncoding()`,这需要知道输入值的范围,超出范围的输入会报错。
  • 使用 `IntegerLookup()`,它将为输入构建一个查找表,并为未知输入值保留一个输出索引。

对于本示例,我们想要一个简单的解决方案来处理推理时超出范围的输入,因此我们将使用 `IntegerLookup()`。

我们还有一个编码为字符串的类别型特征:`thal`。我们将创建所有可能特征的索引,并使用 `StringLookup()` 层进行输出编码。

最后,以下特征是连续的数值型特征

  • 年龄
  • trestbps
  • chol
  • thalach
  • oldpeak
  • slope

对于这些特征中的每一个,我们将使用 `Normalization()` 层来确保每个特征的均值为 0,标准差为 1。

下面,我们定义一个效用函数来执行这些操作

  • 处理以独热编码字符串或整数类别型特征。
# Tensorflow required for tf.data.Dataset
import tensorflow as tf


# We process our datasets elements here (categorical) and convert them to indices to avoid this step
# during model training since only tensorflow support strings.
def encode_categorical(features, target):
    for f in features:
        if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Create a lookup to convert a string values to an integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and  num_oov_indices to 0.
            cls = (
                layers.StringLookup
                if features[f].dtype == "string"
                else layers.IntegerLookup
            )
            features[f] = cls(
                vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
                mask_token=None,
                num_oov_indices=0,
                output_mode="binary",
            )(features[f])

    # Change features from OrderedDict to Dict to match Inputs as they are Dict.
    return dict(features), target

让我们为每个数据框生成 `tf.data.Dataset` 对象

def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = (
        tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
        .map(encode_categorical)
        .shuffle(buffer_size=len(dataframe))
    )
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)

每个 `Dataset` 生成一个元组 `(input, target)`,其中 `input` 是一个特征字典,`target` 的值为 0 或 1

for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=37>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=120>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=215>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 0, 0])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=170>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 0, 0])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>}
Target: tf.Tensor(0, shape=(), dtype=int64)

让我们对数据集进行批处理

train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)

对 Autokeras `Graph` 进行子类化

在这里,我们对 Autokeras `Graph` 进行子类化

  • `build`:我们重写此方法以处理以字典形式传递的模型 `Inputs`。在结构化数据分析中,`Inputs` 通常以字典形式传递,用于每个感兴趣的特征。
class Graph(ak.graph.Graph):

    def build(self, hp):
        """Build the HyperModel into a Keras Model."""
        keras_nodes = {}
        keras_input_nodes = []
        for node in self.inputs:
            node_id = self._node_to_id[node]
            input_node = node.build_node(hp)
            output_node = node.build(hp, input_node)
            keras_input_nodes.append(input_node)
            keras_nodes[node_id] = output_node
        for block in self.blocks:
            temp_inputs = (
                {
                    n.name: keras_nodes[self._node_to_id[n]]
                    for n in block.inputs
                    if isinstance(n, ak.Input)
                }
                if isinstance(block.inputs[0], ak.Input)
                else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
            )
            outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
            for n, o in zip(block.outputs, outputs):
                keras_nodes[self._node_to_id[n]] = o
        model = keras.models.Model(
            keras_input_nodes,
            [
                keras_nodes[self._node_to_id[output_node]]
                for output_node in self.outputs
            ],
        )
        return self._compile_keras_model(hp, model)

    def _compile_keras_model(self, hp, model):
        # Specify hyperparameters from compile(...)
        optimizer_name = hp.Choice(
            "optimizer",
            ["adam", "sgd"],
            default="adam",
        )
        learning_rate = hp.Choice(
            "learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
        )
        if optimizer_name == "adam":
            optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        elif optimizer_name == "sgd":
            optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
        model.compile(
            optimizer=optimizer,
            metrics=self._get_metrics(),
            loss=self._get_loss(),
        )
        return model

对 Autokeras `Input` 进行子类化

在这里,我们对 Autokeras `Input` 节点对象进行子类化,并将其 `dtype` 属性从 `None` 重写为用户提供的值。我们还重写 `build_node` 方法,以便对 `Inputs` 层使用用户提供的名称。

class Input(ak.Input):
    def __init__(self, dtype, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # Override dtype to a user dtype value
        self.dtype = dtype
        self.name = name

    def build_node(self, hp):
        return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)

对 `ClassificationHead` 进行子类化

在这里,我们对 Autokeras `ClassificationHead` 进行子类化,并重写 `__init__` 方法,同时添加 `get_expected_shape` 方法来推断标签的形状。我们移除了预处理功能,因为我们更倾向于手动执行此类操作。

class ClassifierHead(ak.ClassificationHead):

    def __init__(
        self,
        num_classes: Optional[int] = None,
        multi_label: bool = False,
        loss: Optional[types.LossType] = None,
        metrics: Optional[types.MetricsType] = None,
        dropout: Optional[float] = None,
        **kwargs,
    ):
        self.num_classes = num_classes
        self.multi_label = multi_label
        self.dropout = dropout
        if metrics is None:
            metrics = ["accuracy"]
        if loss is None:
            loss = self.infer_loss()
        ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
        self.shape = self.get_expected_shape()

    def get_expected_shape(self):
        # Compute expected shape from num_classes.
        if self.num_classes == 2 and not self.multi_label:
            return [1]
        return [self.num_classes]

`GatedLinearUnit` 层

这是在脚本 `structured_data/classification_with_grn_vsn.py` 中定义的 Keras 层。有关此层的更多详细信息可在相关脚本中找到。

class GatedLinearUnit(layers.Layer):
    def __init__(self, num_units, activation, **kwargs):
        super().__init__(**kwargs)
        self.linear = layers.Dense(num_units)
        self.sigmoid = layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        return self.linear(inputs) * self.sigmoid(inputs)

    def build(self):
        self.built = True

`GatedResidualNetwork` 层

这是在脚本 `structured_data/classification_with_grn_vsn.py` 中定义的 Keras 层。有关此层的更多详细信息可在相关脚本中找到。

class GatedResidualNetwork(layers.Layer):

    def __init__(
        self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_units = num_units
        self.use_layernorm = use_layernorm
        self.elu_dense = layers.Dense(num_units, activation=activation)
        self.linear_dense = layers.Dense(num_units)
        self.dropout = layers.Dropout(dropout_rate)
        self.gated_linear_unit = GatedLinearUnit(num_units, activation)
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(num_units)

    def call(self, inputs, hp):
        x = self.elu_dense(inputs)
        x = self.linear_dense(x)
        x = self.dropout(x)
        if inputs.shape[-1] != self.num_units:
            inputs = self.project(inputs)
        x = inputs + self.gated_linear_unit(x)
        use_layernorm = self.use_layernorm
        if use_layernorm is None:
            use_layernorm = hp.Boolean("use_layernorm", default=True)
        if use_layernorm:
            x = self.layer_norm(x)
        return x

    def build(self):
        self.built = True

构建 Autokeras `VariableSelection Block`

我们将以下 Keras 层转换为 Autokeras `Block`,以包含可调优的超参数。有关编写自定义 `Block` 的信息,请参阅 Autokeras blocks API。

class VariableSelection(ak.Block):
    def __init__(
        self,
        num_units: Optional[Union[int, HyperParameters.Choice]] = None,
        dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
        activation: Optional[Union[str, HyperParameters.Choice]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dropout = utils.get_hyperparameter(
            dropout_rate,
            HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
            float,
        )
        self.num_units = utils.get_hyperparameter(
            num_units,
            HyperParameters().Choice(
                "num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
            ),
            int,
        )
        self.activation = utils.get_hyperparameter(
            activation,
            HyperParameters().Choice(
                "vsn_activation", ["sigmoid", "elu"], default="sigmoid"
            ),
            str,
        )

    def build(self, hp, inputs):
        num_units = utils.add_to_hp(self.num_units, hp, "num_units")
        dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
        activation = utils.add_to_hp(self.activation, hp, "activation")
        concat_inputs = []
        # Project the features to 'num_units' dimension
        for input_ in inputs:
            if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(inputs[input_])
                )
            else:
                # Create a Normalization layer for our feature
                normalizer = layers.Normalization()
                # Prepare a Dataset that only yields our feature
                feature_ds = train_ds.map(lambda x, y: x[input_]).map(
                    lambda x: keras.ops.expand_dims(x, -1)
                )
                # Learn the statistics of the data
                normalizer.adapt(feature_ds)
                # Normalize the input feature
                normal_feature = normalizer(inputs[input_])
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(normal_feature)
                )
        v = layers.concatenate(concat_inputs)
        v = GatedResidualNetwork(
            num_units=num_units, dropout_rate=dropout_rate, activation=activation
        )(v, hp=hp)
        v = keras.ops.expand_dims(
            layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
        )
        x = []
        x += [
            GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
            for i in concat_inputs
        ]
        x = keras.ops.stack(x, axis=1)
        return keras.ops.squeeze(
            keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
        )

我们创建 `HyperModel`(来自 KerasTuner)的 `Inputs`,它们将被构建为 Keras `Input` 对象

# Categorical features have different shapes after the encoding, dependent on the
# vocabulary or unique values of each feature. We create them accordinly to match the
# input data elements generated by tf.data.Dataset after pre-processing them
def create_model_inputs():
    inputs = {
        f: (
            Input(
                name=f,
                shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
                dtype="int64",
            )
            if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
            else Input(name=f, shape=(1,), dtype="float32")
        )
        for f in FEATURE_NAMES
    }
    return inputs

KerasTuner `HyperModel`

在这里,我们使用 Autokeras `Functional` API 构建一个 `Blocks` 网络,该网络将最终构建为 KerasTuner `HyperModel` 和 Keras `Model`。

class MyHyperModel(keras_tuner.HyperModel):

    def build(self, hp):
        inputs = create_model_inputs()
        features = VariableSelection()(inputs)
        outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
        model = Graph(inputs=inputs, outputs=outputs)
        model = model.build(hp)
        return model

    def fit(self, hp, model, *args, **kwargs):
        return model.fit(
            *args,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            **kwargs,
        )

使用 `RandomSearch Tuner` 寻找最佳超参数

我们使用 `RandomSearch Tuner` 在搜索空间中搜索超参数。我们还显示搜索空间。

print("Start training and searching for the best model...")

tuner = keras_tuner.RandomSearch(
    MyHyperModel(),
    objective="val_accuracy",
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="tune_hypermodel",
)

# Show the search space summary
print("Tuner search space summary:\n")
tuner.search_space_summary()
# Search for best model
tuner.search(train_ds, epochs=2, validation_data=val_ds)
Trial 3 Complete [00h 00m 16s]
val_accuracy: 0.8032786846160889
Best val_accuracy So Far: 0.8032786846160889
Total elapsed time: 00h 00m 34s

提取最佳模型

# Get the top model.
models = tuner.get_best_models(num_models=1)
best_model = models[0]
best_model.summary()
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 346 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ age (InputLayer)    │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ chol (InputLayer)   │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ oldpeak             │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ slope (InputLayer)  │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ thalach             │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ trestbps            │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32     │ (None, 1)         │          0 │ age[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ ca (InputLayer)     │ (None, 4)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_2   │ (None, 1)         │          0 │ chol[0][0]        │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cp (InputLayer)     │ (None, 5)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ exang (InputLayer)  │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ fbs (InputLayer)    │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_6   │ (None, 1)         │          0 │ oldpeak[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ restecg             │ (None, 3)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ sex (InputLayer)    │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_9   │ (None, 1)         │          0 │ slope[0][0]       │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ thal (InputLayer)   │ (None, 5)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_11  │ (None, 1)         │          0 │ thalach[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_12  │ (None, 1)         │          0 │ trestbps[0][0]    │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization       │ (None, 1)         │          3 │ cast_to_float32[ │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_1   │ (None, 4)         │          0 │ ca[0][0]          │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_1     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_3   │ (None, 5)         │          0 │ cp[0][0]          │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_4   │ (None, 2)         │          0 │ exang[0][0]       │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_5   │ (None, 2)         │          0 │ fbs[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_2     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_7   │ (None, 3)         │          0 │ restecg[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_8   │ (None, 2)         │          0 │ sex[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_3     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_10  │ (None, 5)         │          0 │ thal[0][0]        │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_4     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_5     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 16)        │         32 │ normalization[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense)     │ (None, 16)        │         80 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_2 (Dense)     │ (None, 16)        │         32 │ normalization_1[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_3 (Dense)     │ (None, 16)        │         96 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_4 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_5 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_6 (Dense)     │ (None, 16)        │         32 │ normalization_2[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_7 (Dense)     │ (None, 16)        │         64 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_8 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_9 (Dense)     │ (None, 16)        │         32 │ normalization_3[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_10 (Dense)    │ (None, 16)        │         96 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_11 (Dense)    │ (None, 16)        │         32 │ normalization_4[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_12 (Dense)    │ (None, 16)        │         32 │ normalization_5[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ concatenate         │ (None, 208)       │          0 │ dense[0][0],      │
│ (Concatenate)       │                   │            │ dense_1[0][0],    │
│                     │                   │            │ dense_2[0][0],    │
│                     │                   │            │ dense_3[0][0],    │
│                     │                   │            │ dense_4[0][0],    │
│                     │                   │            │ dense_5[0][0],    │
│                     │                   │            │ dense_6[0][0],    │
│                     │                   │            │ dense_7[0][0],    │
│                     │                   │            │ dense_8[0][0],    │
│                     │                   │            │ dense_9[0][0],    │
│                     │                   │            │ dense_10[0][0],   │
│                     │                   │            │ dense_11[0][0],   │
│                     │                   │            │ dense_12[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      7,536 │ concatenate[0][0] │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_18 (Dense)    │ (None, 13)        │        221 │ gated_residual_n… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ expand_dims         │ (None, 13, 1)     │          0 │ dense_18[0][0]    │
│ (ExpandDims)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense[0][0]       │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_1[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_2[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_3[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_4[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_5[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_6[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_7[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_8[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_9[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_10[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_11[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_12[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transpose           │ (None, 1, 13)     │          0 │ expand_dims[0][0] │
│ (Transpose)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stack (Stack)       │ (None, 13, 16)    │          0 │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ matmul (Matmul)     │ (None, 1, 16)     │          0 │ transpose[0][0],  │
│                     │                   │            │ stack[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ squeeze (Squeeze)   │ (None, 16)        │          0 │ matmul[0][0]      │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_14          │ (None, 16)        │          0 │ squeeze[0][0]     │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_84 (Dense)    │ (None, 1)         │         17 │ dropout_14[0][0]  │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ classifier_head_1   │ (None, 1)         │          0 │ dense_84[0][0]    │
│ (Activation)        │                   │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 23,024 (89.96 KB)
 Trainable params: 23,006 (89.87 KB)
 Non-trainable params: 18 (96.00 B)

在新数据上进行推理

要获取新样本的预测结果,你只需调用 `model.predict()`。只需要做两件事

  1. 将标量封装到列表中以获得批处理维度(模型仅处理数据批次,而非单个样本)
  2. 对每个特征调用 `convert_to_tensor`
sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}


# Given the category (in the sample above - key) and the category value (in the sample above - value),
# we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
    # Create a list of zeros with the same length as categories
    encoding = [0] * len(cat)
    # Find the index of category_value in categories and set the corresponding position to 1
    if cat_value in cat:
        encoding[cat.index(cat_value)] = 1
    return encoding


for name, value in sample.items():
    if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
        sample.update(
            {
                name: get_cat_encoding(
                    CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
                )
            }
        )
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = best_model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having a heart disease, "
    "as evaluated by our model."
)

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step



1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 162ms/step

This particular patient had a 28.1 percent probability of having a heart disease, as evaluated by our model.