Keras 3 API 文档 / 操作 API / 核心操作

核心操作

[源码]

associative_scan 函数

keras.ops.associative_scan(f, elems, reverse=False, axis=0)

使用结合的二元运算并行执行扫描。

此操作类似于 scan,主要区别在于 associative_scan 是并行实现,可能具有显着的性能优势,尤其是在 jit 编译时。需要注意的是,它只能在 f 是二元结合运算时使用(即,它必须验证 f(a, f(b, c)) == f(f(a, b), c))。

有关结合扫描的介绍,请参阅此论文:Blelloch, Guy E. 1990. 前缀和及其应用

参数

  • f: 一个 Python 可调用对象,实现一个签名 r = f(a, b) 的结合二元运算。函数 f 必须是结合的,即,它必须满足等式 f(a, f(b, c)) == f(f(a, b), c)。输入和结果是与 elems 匹配的(可能是嵌套的 Python 树结构)数组。每个数组都有一个维度代替 axis 维度。f 应该在 axis 维度上逐元素应用。结果 r 具有与两个输入 ab 相同的形状(和结构)。
  • elems: 一个(可能是嵌套的 Python 树结构)数组,每个数组都有一个大小为 num_elemsaxis 维度。
  • reverse: 一个布尔值,表示扫描是否应该相对于 axis 维度反转。
  • axis: 一个整数,标识扫描应发生的轴。

返回

一个(可能是嵌套的 Python 树结构)数组,其形状和结构与 elems 相同,其中 axis 的第 k 个元素是将 f 递归应用于组合 elems 沿 axis 的前 k 个元素的结果。例如,给定 elems = [a, b, c, ...],结果将是 [a, f(a, b), f(f(a, b), c), ...]

示例

>>> sum_fn = lambda x, y: x + y
>>> xs = keras.ops.arange(5)
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[0, 1, 3, 6, 10]
>>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
>>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)]
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[[1, 3], [1, 3], [1, 3]]

[源码]

cast 函数

keras.ops.cast(x, dtype)

将张量转换为所需的 dtype。

参数

  • x: 一个张量或变量。
  • dtype: 目标类型。

返回

一个指定 dtype 的张量。

示例

>>> x = keras.ops.arange(4)
>>> x = keras.ops.cast(x, dtype="float16")

[源码]

cond 函数

keras.ops.cond(pred, true_fn, false_fn)

有条件地应用 true_fnfalse_fn

参数

  • pred: 布尔标量类型
  • true_fn: 可调用对象,返回 pred == True 情况下的输出。
  • false_fn: 可调用对象,返回 pred == False 情况下的输出。

返回

true_fnfalse_fn 的输出,取决于 pred。


[源码]

convert_to_numpy 函数

keras.ops.convert_to_numpy(x)

将张量转换为 NumPy 数组。

参数

  • x: 一个张量。

返回

一个 NumPy 数组。


[源码]

convert_to_tensor 函数

keras.ops.convert_to_tensor(x, dtype=None, sparse=None)

将 NumPy 数组转换为张量。

参数

  • x: 一个 NumPy 数组,Python 数组(可以是嵌套的)或后端张量。
  • dtype: 目标类型。如果为 None,则使用 x 的类型。
  • sparse: 是否保留稀疏张量。False 将导致稀疏张量被密集化。None 的默认值意味着只有在后端支持时才保留稀疏张量。

返回

一个指定 dtype 和稀疏度的后端张量。

示例

>>> x = np.array([1, 2, 3])
>>> y = keras.ops.convert_to_tensor(x)

[源码]

custom_gradient 函数

keras.ops.custom_gradient(f)

用于定义具有自定义梯度的函数的装饰器。

此装饰器允许对一系列操作的梯度进行细粒度控制。这可能出于多种原因有用,包括为一系列操作提供更有效或数值稳定的梯度。

参数

  • f: 函数 f(*args),返回一个元组 (output, grad_fn),其中
    • args 是(嵌套的)张量输入序列。
    • output 是将 forward_fn 中的操作应用于 args 的(嵌套的)张量输出。
    • grad_fn 是一个函数,其签名为 grad_fn(*args, upstream),它返回一个与(扁平化的)args 大小相同的张量元组:output 中张量相对于 args 中张量的导数。upstream 是一个张量或张量序列,保存 output 中每个张量的初始值梯度。

返回

一个函数 h(*args),它返回与 f(*args)[0] 相同的值,其梯度由 f(*args)[1] 确定。

示例

  1. 与后端无关的示例。
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        if upstream is None:
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad

请注意,返回梯度计算的 grad 函数需要 args 以及 upstream 关键字参数,具体取决于设置的后端。对于 JAX 和 TensorFlow 后端,它只需要一个参数,而在 PyTorch 后端的情况下,它可能会使用 upstream 参数。

在使用 TensorFlow/JAX 后端时,grad(upstream) 就足够了。对于 PyTorch,grad 函数需要 *args 以及 upstream,例如 def grad(*args, upstream)。请按照前面的示例以与所有后端兼容的方式使用 @ops.custom_gradient

  1. 这是 JAX 和 TensorFlow 特有的示例
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad
  1. 最后,这是 PyTorch 特有的示例,使用 *argsupstream
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(*args, upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad

[源码]

dtype 函数

keras.ops.dtype(x)

以标准化的字符串形式返回张量输入的 dtype。

请注意,由于标准化,dtype 将不与 dtype 的后端特定版本进行比较。

参数

  • x: 一个张量。此函数将尝试访问输入张量的 dtype 属性。

返回

一个字符串,指示输入张量的 dtype,例如 "float32"

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.dtype(x)
'float32'

[源码]

erf 函数

keras.ops.erf(x)

逐元素计算 x 的误差函数。

参数

  • x: 输入张量。

返回

一个与 x 具有相同 dtype 的张量。

示例

>>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])
>>> keras.ops.erf(x)
array([-0.99998 , -0.99532, -0.842701,  0.,  0.842701], dtype=float32)

[源码]

erfinv 函数

keras.ops.erfinv(x)

逐元素计算 x 的逆误差函数。

参数

  • x: 输入张量。

返回

一个与 x 具有相同 dtype 的张量。

示例

>>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
>>> keras.ops.erfinv(x)
array([-0.47694, -0.17914, -0.08886,  0. ,  0.27246], dtype=float32)

[源码]

extract_sequences 函数

keras.ops.extract_sequences(x, sequence_length, sequence_stride)

将最后一个轴的维度扩展为长度为 sequence_length 的序列。

在输入的最后一个轴上滑动大小为 sequence_length 的窗口,步长为 sequence_stride,并将最后一个轴替换为 [num_sequences, sequence_length] 序列。

如果沿最后一个轴的维度为 N,则序列的数量可以计算为

num_sequences = 1 + (N - sequence_length) // sequence_stride

参数

  • x: 输入张量。
  • sequence_length: 一个表示序列长度的整数。
  • sequence_stride: 一个表示序列跳跃大小的整数。

返回

一个形状为 [..., num_sequences, sequence_length] 的序列张量。

示例

>>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
>>> extract_sequences(x, 3, 2)
array([[1, 2, 3],
   [3, 4, 5]])

[源码]

fori_loop 函数

keras.ops.fori_loop(lower, upper, body_fun, init_val)

For 循环实现。

参数

  • lower: 循环变量的初始值。
  • upper: 循环变量的上限。
  • body_fun: 一个表示循环主体的可调用对象。必须接受两个参数:循环变量和循环状态。循环状态应由此函数更新和返回。
  • init_val: 循环状态的初始值。

返回

循环后的最终状态。

示例

>>> lower = 0
>>> upper = 10
>>> body_fun = lambda i, s: (i + 1, s + i)
>>> init_val = 0
>>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
45

[源码]

in_top_k 函数

keras.ops.in_top_k(targets, predictions, k)

检查目标是否在前 k 个预测中。

参数

  • targets: 真实标签的张量。
  • predictions: 预测标签的张量。
  • k: 一个整数,表示要考虑的预测数量。

返回

一个与 targets 形状相同的布尔张量,其中每个元素指示相应的目标是否在前 k 个预测中。

示例

>>> targets = keras.ops.convert_to_tensor([2, 5, 3])
>>> predictions = keras.ops.convert_to_tensor(
... [[0.1, 0.4, 0.6, 0.9, 0.5],
...  [0.1, 0.7, 0.9, 0.8, 0.3],
...  [0.1, 0.6, 0.9, 0.9, 0.5]])
>>> in_top_k(targets, predictions, k=3)
array([ True False  True], shape=(3,), dtype=bool)

[源码]

is_tensor 函数

keras.ops.is_tensor(x)

检查给定对象是否为张量。

注意:这将检查后端特定的张量,因此如果您的后端是 PyTorch 或 JAX,则传递 TensorFlow 张量将返回 False

参数

  • x: 一个变量。

返回

如果 x 是张量,则为 True,否则为 False


[源码]

logsumexp 函数

keras.ops.logsumexp(x, axis=None, keepdims=False)

计算张量中元素的指数和的对数。

参数

  • x: 输入张量。
  • axis: 一个整数或整数元组,指定计算和的轴。如果为 None,则对所有元素求和。默认为 None
  • keepdims: 一个布尔值,指示在计算总和时是否保留输入张量的维度。默认为 False

返回

一个张量,包含 x 中元素的指数和的对数。

示例

>>> x = keras.ops.convert_to_tensor([1., 2., 3.])
>>> logsumexp(x)
3.407606

[源码]

map 函数

keras.ops.map(f, xs)

将函数映射到前导数组轴上。

类似于 Python 的内置 map,只是输入和输出采用堆叠数组的形式。考虑使用 vectorized_map() 转换,除非您需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算。

xs 是数组类型时,map() 的语义由以下 Python 实现给出

def map(f, xs):
    return np.stack([f(x) for x in xs])

参数

  • f: 可调用对象,定义要对 xs 的第一个轴或多个轴逐元素应用的函数。
  • xs: 要沿前导轴映射的值。

返回

映射后的值。

示例

>>> f = lambda x: x**2
>>> xs = keras.ops.arange(10)
>>> ys = keras.ops.map(f, xs)
>>> ys
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> f = lambda x: {"y1": x**2, "y2": x * 10}  # Can have nested outputs
>>> ys = keras.ops.map(f, xs)
>>> ys["y1"]
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> ys["y2"]
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

[源码]

rsqrt 函数

keras.ops.rsqrt(x)

逐元素计算 x 的平方根的倒数。

参数

  • x: 输入张量

返回

一个与 x 具有相同 dtype 的张量。

示例

>>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0])
>>> keras.ops.rsqrt(x)
array([1.0, 0.31622776, 0.1], dtype=float32)

[源码]

saturate_cast 函数

keras.ops.saturate_cast(x, dtype)

执行到所需 dtype 的安全饱和转换。

当转换为具有较小值范围的 dtype 时,饱和转换可以防止数据类型溢出。例如,ops.cast(ops.cast([-1, 256], "float32"), "uint8") 返回 [255, 0],但 ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8") 返回 [0, 255]

参数

  • x: 一个张量或变量。
  • dtype: 目标类型。

返回

一个安全转换为指定 dtype 的张量。

示例

使用双三次插值调整图像大小时可能会产生超出原始范围的值。

>>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
>>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
>>> print(image4x4.numpy().squeeze())
>>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
>>> #  [ 52.526054  52.82143   53.407146  53.70253 ]
>>> #  [201.29752  201.59288  202.17859  202.47395 ]
>>> #  [276.32355  276.61893  277.20465  277.50006 ]]

将此调整大小后的图像转换回 uint8 会导致溢出。

>>> image4x4_casted = ops.cast(image4x4, "uint8")
>>> print(image4x4_casted.numpy().squeeze())
>>> # [[234 234 235 235]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [ 20  20  21  21]]

饱和转换为 uint8 将在转换之前将值裁剪到 uint8 范围,并且不会导致溢出。

>>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
>>> print(image4x4_saturate_casted.numpy().squeeze())
>>> # [[  0   0   0   0]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [255 255 255 255]]

[源码]

scan 函数

keras.ops.scan(f, init, xs=None, length=None, reverse=False, unroll=1)

在领先的数组轴上扫描一个函数,同时携带状态。

xs 的类型为数组类型或 None,并且 ys 的类型为数组类型时,scan() 的语义大致由以下 Python 实现给出

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

循环携带值 carry (init) 在所有迭代中必须保持固定的形状和数据类型。

在 TensorFlow 中,y 的形状和数据类型必须与 carry 匹配。在其他后端中,这不是必需的。

参数

  • f:可调用对象,定义每次循环迭代的逻辑。它接受两个参数,第一个是循环携带值,第二个是 xs 沿其主轴的切片。此可调用对象返回一个对,其中第一个表示循环携带的新值,第二个表示输出的切片。
  • init:初始循环携带值。它可以是标量、张量或任何嵌套结构。它必须与 f 返回的第一个元素的结构匹配。
  • xs:可选值,沿其主轴扫描。它可以是张量或任何嵌套结构。如果未提供 xs,则必须指定 length 来定义循环迭代的次数。默认为 None
  • length:可选整数,指定循环迭代的次数。如果未提供 length,则默认为 xs 中数组主轴的大小。默认为 None
  • reverse:可选布尔值,指定是正向运行扫描迭代还是反向运行,相当于反转 xsys 中数组的主轴。
  • unroll:可选正整数或布尔值,指定在一个循环的单次迭代中展开多少次扫描迭代。如果提供整数,则确定在一个循环的单次迭代中运行多少次展开的循环迭代。如果提供布尔值,则确定循环是完全展开 (unroll=True) 还是完全不展开 (unroll=False)。请注意,展开仅受 JAX 和 TensorFlow 后端支持。

返回

一个对,其中第一个元素表示最终循环携带值,第二个元素表示在输入的主轴上扫描时 f 的堆叠输出。

示例

>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = keras.ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
[1, 3, 6, 10, 15]

[源码]

scatter 函数

keras.ops.scatter(indices, values, shape)

返回一个形状为 shape 的张量,其中 indices 设置为 values

从高层次来看,此操作执行 zeros[indices] = updates 并返回输出。它等价于

zeros = keras.ops.zeros(shape)
output = keras.ops.scatter_update(zeros, indices, values)

参数

  • indices:一个张量或列表/元组,指定 values 中值的索引。
  • values:一个张量,是要在 indices 处设置的值。
  • shape:输出张量的形状。

示例

>>> indices = [[0, 1], [1, 1]]
>>> values = np.array([1., 1.])
>>> keras.ops.scatter(indices, values, shape=(2, 2))
array([[0., 1.],
       [0., 1.]])

[源码]

scatter_update 函数

keras.ops.scatter_update(inputs, indices, updates)

通过分散(稀疏)索引处的更新来更新输入。

从高层次来看,此操作执行 inputs[indices] = updates。假设 inputs 是一个形状为 (D0, D1, ..., Dn) 的张量,scatter_update 有 2 个主要用法。

  1. indices 是一个形状为 (num_updates, n) 的 2D 张量,其中 num_updates 是要执行的更新次数,而 updates 是一个形状为 (num_updates,) 的 1D 张量。例如,如果 inputszeros((4, 4, 4)),并且我们想将 inputs[1, 2, 3]inputs[0, 1, 3] 更新为 1,那么我们可以使用
inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]]
updates = np.array([1., 1.])
inputs = keras.ops.scatter_update(inputs, indices, updates)

2 indices 是一个形状为 (num_updates, k) 的 2D 张量,其中 num_updates 是要执行的更新次数,而 k (k < n) 是 indices 中每个索引的大小。updates 是一个形状为 (num_updates, inputs.shape[k:])n - k-D 张量。例如,如果 inputs = np.zeros((4, 4, 4)),并且我们想将 inputs[1, 2, :]inputs[2, 3, :] 更新为 [1, 1, 1, 1],那么 indices 的形状将为 (num_updates, 2) (k = 2),而 updates 的形状将为 (num_updates, 4) (inputs.shape[2:] = 4)。请参见下面的代码

inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]]
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
inputs = keras.ops.scatter_update(inputs, indices, updates)

参数

  • inputs:一个张量,是要更新的张量。
  • indices:一个形状为 (N, inputs.ndim) 的张量或列表/元组,指定要更新的索引。N 是要更新的索引的数量,必须等于 updates 的第一个维度。
  • updates:一个张量,是要在 indices 处放入 inputs 的新值。

返回

一个张量,具有与 inputs 相同的形状和数据类型。


[源码]

segment_max 函数

keras.ops.segment_max(data, segment_ids, num_segments=None, sorted=False)

计算张量中段的最大值。

参数

  • data:输入张量。
  • segment_ids:一个 N 维张量,包含 data 中每个元素的段索引。data.shape[:len(segment_ids.shape)] 应该匹配。
  • num_segments:一个整数,表示段的总数。如果未指定,则从 segment_ids 中的最大值推断得出。
  • sorted:一个布尔值,指示 segment_ids 是否已排序。默认为 False

返回

一个包含段最大值的张量,其中每个元素表示 data 中相应段的最大值。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_max(data, segment_ids, num_segments)
array([2, 20, 200], dtype=int32)

[源码]

segment_sum 函数

keras.ops.segment_sum(data, segment_ids, num_segments=None, sorted=False)

计算张量中段的总和。

参数

  • data:输入张量。
  • segment_ids:一个 N 维张量,包含 data 中每个元素的段索引。段 ID 的维度数应严格小于或等于数据中的维度数。
  • num_segments:一个整数,表示段的总数。如果未指定,则从 segment_ids 中的最大值推断得出。
  • sorted:一个布尔值,指示 segment_ids 是否已排序。默认为 False

返回

一个包含段总和的张量,其中每个元素表示 data 中相应段的总和。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_sum(data, segment_ids,num_segments)
array([3, 30, 300], dtype=int32)

[源码]

shape 函数

keras.ops.shape(x)

获取张量输入的形状。

注意:在 TensorFlow 后端,当 x 是一个具有动态形状的 tf.Tensor 时,在编译函数的上下文中是动态的维度将具有 tf.Tensor 值,而不是静态整数值。

参数

  • x:一个张量。此函数将尝试访问输入张量的 shape 属性。

返回

一个整数或 None 值的元组,指示输入张量的形状。

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.shape(x)
(8, 12)

[源码]

slice 函数

keras.ops.slice(inputs, start_indices, shape)

返回输入张量的切片。

从高层次来看,此操作是数组切片的显式替代,例如 inputs[start_indices: start_indices + shape]。与通过括号切片不同,此操作将在所有后端接受张量起始索引,这在通过其他张量操作动态计算索引时非常有用。

inputs = np.zeros((5, 5))
start_indices = np.array([3, 3])
shape = np.array([2, 2])
inputs = keras.ops.slice(inputs, start_indices, shape)

参数

  • inputs:一个张量,是要更新的张量。
  • start_indices:一个形状为 (inputs.ndim,) 的列表/元组,指定更新的起始索引。
  • shape:返回切片的完整形状。

返回

一个张量,具有与 inputs 相同的形状和数据类型。


[源码]

slice_update 函数

keras.ops.slice_update(inputs, start_indices, updates)

通过切入已更新值的张量来更新输入。

从高层次来看,此操作执行 inputs[start_indices: start_indices + updates.shape] = updates。假设 inputs 是一个形状为 (D0, D1, ..., Dn) 的张量,start_indices 必须是 n 个整数的列表/元组,指定起始索引。updates 的秩必须与 inputs 相同,并且每个维的大小不得超过 Di - start_indices[i]。例如,如果我们有 2D 输入 inputs = np.zeros((5, 5)),并且我们想将最后 2 行和最后 2 列的交集更新为 1,即 inputs[3:, 3:] = np.ones((2, 2)),那么我们可以使用下面的代码

inputs = np.zeros((5, 5))
start_indices = [3, 3]
updates = np.ones((2, 2))
inputs = keras.ops.slice_update(inputs, start_indices, updates)

参数

  • inputs:一个张量,是要更新的张量。
  • start_indices:一个形状为 (inputs.ndim,) 的列表/元组,指定更新的起始索引。
  • updates:一个张量,是要在 indices 处放入 inputs 的新值。updates 的秩必须与 inputs 相同。

返回

一个张量,具有与 inputs 相同的形状和数据类型。


[源码]

stop_gradient 函数

keras.ops.stop_gradient(variable)

停止梯度计算。

参数

  • variable:要禁用梯度计算的张量变量。

返回

禁用梯度计算的变量。

示例

>>> var = keras.backend.convert_to_tensor(
...     [1., 2., 3.],
...     dtype="float32"
... )
>>> var = keras.ops.stop_gradient(var)

[源码]

switch 函数

keras.ops.switch(index, branches, *operands)

根据 index 应用 branches 中的一个分支。

如果 index 超出范围,则会将其钳制在范围内。

switch 的语义大致由以下 Python 实现给出

def switch(index, branches, *operands):
    index = clamp(0, index, len(branches) - 1)
    return branches[index](*operands)

参数

  • index:一个整数标量,指示要应用哪个分支函数。
  • branches:一个基于 index 应用的函数序列。
  • operands:应用于任何分支的输入。

返回

基于 index 选择的分支的 branch(*operands) 的输出。

示例

>>> add_fn = lambda x, y: x + y
>>> subtract_fn = lambda x, y: x - y
>>> x = keras.ops.array(2.0)
>>> y = keras.ops.array(0.5)
>>> branches = [add_fn, subtract_fn]
>>> keras.ops.switch(0, branches, x, y)
2.5
>>> keras.ops.switch(1, branches, x, y)
1.5

[源码]

top_k 函数

keras.ops.top_k(x, k, sorted=True)

查找张量中前 k 个值及其索引。

参数

  • x: 输入张量。
  • k:一个整数,表示要检索的前几个元素的数量。
  • sorted:一个布尔值,指示是否以降序对输出进行排序。默认为 True

返回

一个包含两个张量的元组。第一个张量包含前 k 个值,第二个张量包含输入张量中前 k 个值的索引。

示例

>>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
>>> values, indices = top_k(x, k=3)
>>> print(values)
array([9 7 5], shape=(3,), dtype=int32)
>>> print(indices)
array([4 2 0], shape=(3,), dtype=int32)

[源码]

unstack 函数

keras.ops.unstack(x, num=None, axis=0)

将秩为 R 的张量的给定维度解包为秩为 (R-1) 的张量。

参数

  • x:输入张量。
  • num:维度轴的长度。如果为 None,则自动推断。
  • axis:要解包的轴。

返回

一个沿着给定轴解包的张量列表。

示例

>>> x = keras.ops.array([[1, 2], [3, 4]])
>>> keras.ops.unstack(x, axis=0)
[array([1, 2]), array([3, 4])]

[源码]

vectorized_map 函数

keras.ops.vectorized_map(function, elements)

在张量 elements 的轴 0 上并行映射 function

示意性地,在单个张量输入 elements 的情况下,vectorized_map 实现如下

def vectorized_map(function, elements)
    outputs = []
    for e in elements:
        outputs.append(function(e))
    return stack(outputs)

在张量 elements 的可迭代对象的情况下,它实现如下

def vectorized_map(function, elements)
    batch_size = elements[0].shape[0]
    outputs = []
    for index in range(batch_size):
        outputs.append(function([e[index] for e in elements]))
    return np.stack(outputs)

在这种情况下,预期 function 将单个张量参数列表作为输入。


[源码]

while_loop 函数

keras.ops.while_loop(cond, body, loop_vars, maximum_iterations=None)

While 循环实现。

参数

  • cond:一个可调用对象,表示循环的终止条件。必须接受类似 loop_vars 的结构作为参数。如果 loop_vars 是一个元组或列表,则 loop_vars 的每个元素将按位置传递给可调用对象。
  • body:一个可调用对象,表示循环主体。必须接受类似 loop_vars 的结构作为参数,并返回具有相同结构的更新值。如果 loop_vars 是一个元组或列表,则 loop_vars 的每个元素将按位置传递给可调用对象。
  • loop_vars:一个任意嵌套的张量状态结构,可在循环迭代中保留。
  • maximum_iterations:while 循环要运行的可选最大迭代次数。如果提供,则 cond 输出将与一个附加条件进行 AND 运算,以确保执行的迭代次数不大于 maximum_iterations

返回

一个张量列表/元组,其形状和数据类型与 inputs 相同。

示例

>>> i = 0
>>> cond = lambda i: i < 10
>>> body = lambda i: i + 1
>>> keras.ops.while_loop(cond, body, i)
10
>>> x, y = 0, 1
>>> cond = lambda x, y: x < 10
>>> body = lambda x, y: (x + 1, y + 1)
>>> keras.ops.while_loop(cond, body, (x, y))
10, 11