Keras 3 API 文档 / 操作 API / 核心操作

核心操作

[源代码]

associative_scan 函数

keras.ops.associative_scan(f, elems, reverse=False, axis=0)

并行执行具有关联二元运算的扫描。

此操作类似于 scan,关键区别在于 associative_scan 是一个并行实现,具有潜在的显著性能优势,尤其是在 JIT 编译时。问题是,它只能在 f 是一个二元关联运算(即必须验证 f(a, f(b, c)) == f(f(a, b), c))时使用。

有关关联扫描的介绍,请参阅本文档:Blelloch,Guy E. 1990. 前缀和及其应用

参数

  • f:一个 Python 可调用对象,实现一个具有签名 r = f(a, b) 的关联二元运算。函数 f 必须是关联的,即必须满足方程 f(a, f(b, c)) == f(f(a, b), c)。输入和结果是(可能是嵌套的 Python 树结构的)数组,与 elems 匹配。每个数组在 axis 维度上都有一个维度。f 应该在 axis 维度上逐元素应用。结果 r 与两个输入 ab 具有相同的形状(和结构)。
  • elems:一个(可能是嵌套的 Python 树结构的)数组,每个数组都有一个大小为 num_elemsaxis 维度。
  • reverse:一个布尔值,表示扫描是否应相对于 axis 维度反转。
  • axis:一个整数,标识应在其上执行扫描的轴。

返回值

一个(可能是嵌套的 Python 树结构的)数组,其形状和结构与 elems 相同,其中 axis 的第 k 个元素是递归应用 f 以组合 elems 沿 axis 的前 k 个元素的结果。例如,给定 elems = [a, b, c, ...],结果将是 [a, f(a, b), f(f(a, b), c), ...]

示例

>>> sum_fn = lambda x, y: x + y
>>> xs = keras.ops.arange(5)
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[0, 1, 3, 6, 10]
>>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
>>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)]
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[[1, 3], [1, 3], [1, 3]]

[源代码]

cast 函数

keras.ops.cast(x, dtype)

将张量转换为所需的 dtype。

参数

  • x:一个张量或变量。
  • dtype:目标类型。

返回值

指定 dtype 的张量。

示例

>>> x = keras.ops.arange(4)
>>> x = keras.ops.cast(x, dtype="float16")

[源代码]

cond 函数

keras.ops.cond(pred, true_fn, false_fn)

有条件地应用 true_fnfalse_fn

参数

  • pred:布尔标量类型
  • true_fn:一个可调用对象,返回 pred == True 情况下的输出。
  • false_fn:一个可调用对象,返回 pred == False 情况下的输出。

返回值

根据 pred,返回 true_fnfalse_fn 的输出。


[源代码]

convert_to_numpy 函数

keras.ops.convert_to_numpy(x)

将张量转换为 NumPy 数组。

参数

  • x:一个张量。

返回值

一个 NumPy 数组。


[源代码]

convert_to_tensor 函数

keras.ops.convert_to_tensor(x, dtype=None, sparse=None)

将 NumPy 数组转换为张量。

参数

  • x:一个 NumPy 数组、Python 数组(可以嵌套)或后端张量。
  • dtype:目标类型。如果为 None,则使用 x 的类型。
  • sparse:是否保留稀疏张量。False 将导致稀疏张量被致密化。None 的默认值表示仅当后端支持时才保留稀疏张量。

返回值

指定 dtype 和稀疏性的后端张量。

示例

>>> x = np.array([1, 2, 3])
>>> y = keras.ops.convert_to_tensor(x)

[源代码]

custom_gradient 函数

keras.ops.custom_gradient(f)

装饰器,用于定义具有自定义梯度的函数。

此装饰器允许对操作序列的梯度进行细粒度控制。这可能出于多种原因有用,包括为一系列操作提供更有效或数值更稳定的梯度。

参数

  • f:函数 f(*args) 返回一个元组 (output, grad_fn),其中
    • args 是函数的(嵌套结构的)张量输入序列。
    • output 是将 forward_fn 中的操作应用于 args 的(嵌套结构的)张量输出。
    • grad_fn 是一个具有签名 grad_fn(*args, upstream) 的函数,它返回一个与(扁平化)args 大小相同的张量元组:output 中张量相对于 args 中张量的导数。upstream 是一个张量或张量序列,包含每个 output 中张量的初始值梯度。

返回值

一个函数 h(*args),它返回与 f(*args)[0] 相同的值,其梯度由 f(*args)[1] 确定。

示例

  1. 与后端无关的示例。
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        if upstream is None:
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad

请注意,返回梯度计算的 grad 函数需要 args 以及一个 upstream 关键字参数,具体取决于设置的后端。对于 JAX 和 TensorFlow 后端,它只需要一个参数,而在 PyTorch 后端的情况下,它可能会使用 upstream 参数。

在使用 TensorFlow/JAX 后端时,grad(upstream) 就足够了。对于 PyTorch,grad 函数需要 *args 以及 upstream,例如 def grad(*args, upstream)。请遵循前面的示例以兼容所有后端的方式使用 @ops.custom_gradient

  1. 以下是 JAX 和 TensorFlow 特定的示例
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad
  1. 最后,这是一个 PyTorch 特定的示例,使用 *argsupstream
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(*args, upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad

[源代码]

dtype 函数

keras.ops.dtype(x)

将张量输入的 dtype 作为标准化字符串返回。

请注意,由于标准化,dtype 将不会与 dtype 的后端特定版本相等。

参数

  • x:一个张量。此函数将尝试访问输入张量的 dtype 属性。

返回值

一个指示输入张量 dtype 的字符串,例如 "float32"

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.dtype(x)
'float32'

[源代码]

erf 函数

keras.ops.erf(x)

逐元素计算 x 的误差函数。

参数

  • x:输入张量。

返回值

一个与 x 具有相同 dtype 的张量。

示例

>>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])
>>> keras.ops.erf(x)
array([-0.99998 , -0.99532, -0.842701,  0.,  0.842701], dtype=float32)

[源代码]

erfinv 函数

keras.ops.erfinv(x)

逐元素计算 x 的逆误差函数。

参数

  • x:输入张量。

返回值

一个与 x 具有相同 dtype 的张量。

示例

>>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
>>> keras.ops.erfinv(x)
array([-0.47694, -0.17914, -0.08886,  0. ,  0.27246], dtype=float32)

[源代码]

extract_sequences 函数

keras.ops.extract_sequences(x, sequence_length, sequence_stride)

将最后一个轴的维度扩展为 sequence_length 的序列。

sequence_stride 的步长在输入的最后一个轴上滑动大小为 sequence_length 的窗口,将最后一个轴替换为 [num_sequences, sequence_length] 序列。

如果最后一个轴上的维度为 N,则序列数可以计算为

num_sequences = 1 + (N - sequence_length) // sequence_stride

参数

  • x:输入张量。
  • sequence_length:一个表示序列长度的整数。
  • sequence_stride:一个表示序列跳跃大小的整数。

返回值

一个形状为 [..., num_sequences, sequence_length] 的序列张量。

示例

>>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
>>> extract_sequences(x, 3, 2)
array([[1, 2, 3],
   [3, 4, 5]])

[源代码]

fori_loop 函数

keras.ops.fori_loop(lower, upper, body_fun, init_val)

For 循环实现。

参数

  • lower:循环变量的初始值。
  • upper:循环变量的上限。
  • body_fun:一个可调用对象,表示循环体。必须接受两个参数:循环变量和循环状态。此函数应更新和返回循环状态。
  • init_val:循环状态的初始值。

返回值

循环结束后最终的状态。

示例

>>> lower = 0
>>> upper = 10
>>> body_fun = lambda i, s: (i + 1, s + i)
>>> init_val = 0
>>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
45

[源代码]

in_top_k 函数

keras.ops.in_top_k(targets, predictions, k)

检查目标是否在 top-k 预测中。

参数

  • targets:一个真实标签张量。
  • predictions:一个预测标签张量。
  • k:一个表示要考虑的预测数量的整数。

返回值

一个与 targets 形状相同的布尔张量,其中每个元素指示相应的目标是否在 top-k 预测中。

示例

>>> targets = keras.ops.convert_to_tensor([2, 5, 3])
>>> predictions = keras.ops.convert_to_tensor(
... [[0.1, 0.4, 0.6, 0.9, 0.5],
...  [0.1, 0.7, 0.9, 0.8, 0.3],
...  [0.1, 0.6, 0.9, 0.9, 0.5]])
>>> in_top_k(targets, predictions, k=3)
array([ True False  True], shape=(3,), dtype=bool)

[源代码]

is_tensor 函数

keras.ops.is_tensor(x)

检查给定对象是否为张量。

注意:这检查后端特定的张量,因此如果您的后端是 PyTorch 或 JAX,则传递 TensorFlow 张量将返回 False

参数

  • x:一个变量。

返回值

如果 x 是张量,则返回 True,否则返回 False


[源代码]

logsumexp 函数

keras.ops.logsumexp(x, axis=None, keepdims=False)

计算张量中元素的指数之和的对数。

参数

  • x:输入张量。
  • axis:一个整数或整数元组,指定计算总和的轴/轴。如果为 None,则在所有元素上计算总和。默认为 None
  • keepdims:一个布尔值,指示在计算总和时是否保留输入张量的维度。默认为 False

返回值

一个包含 x 中元素的指数之和的对数的张量。

示例

>>> x = keras.ops.convert_to_tensor([1., 2., 3.])
>>> logsumexp(x)
3.407606

[源代码]

map 函数

keras.ops.map(f, xs)

在领先的数组轴上映射函数。

类似于 Python 的内置 map,除了输入和输出采用堆叠数组的形式。请考虑使用 vectorized_map() 变换,除非您需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算。

xs 是数组类型时,map() 的语义由以下 Python 实现给出

def map(f, xs):
    return np.stack([f(x) for x in xs])

参数

  • f:可调用对象,定义要沿 xs 的第一个轴或轴逐元素应用的函数。
  • xs:沿前导轴映射的值。

返回值

映射的值。

示例

>>> f = lambda x: x**2
>>> xs = keras.ops.arange(10)
>>> ys = keras.ops.map(f, xs)
>>> ys
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> f = lambda x: {"y1": x**2, "y2": x * 10}  # Can have nested outputs
>>> ys = keras.ops.map(f, xs)
>>> ys["y1"]
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> ys["y2"]
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

[源代码]

rsqrt 函数

keras.ops.rsqrt(x)

逐元素计算 x 的平方根的倒数。

参数

  • x:输入张量

返回值

一个与 x 具有相同 dtype 的张量。

示例

>>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0])
>>> keras.ops.rsqrt(x)
array([1.0, 0.31622776, 0.1], dtype=float32)

[源代码]

saturate_cast 函数

keras.ops.saturate_cast(x, dtype)

执行安全饱和转换到所需的 dtype。

饱和转换在转换为范围较小的dtype时防止数据类型溢出。例如,ops.cast(ops.cast([-1, 256], "float32"), "uint8")返回[255, 0],但ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")返回[0, 255]

参数

  • x:一个张量或变量。
  • dtype:目标类型。

返回值

一个安全转换后的指定dtype的张量。

示例

使用双三次插值进行图像缩放可能会产生超出原始范围的值。

>>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
>>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
>>> print(image4x4.numpy().squeeze())
>>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
>>> #  [ 52.526054  52.82143   53.407146  53.70253 ]
>>> #  [201.29752  201.59288  202.17859  202.47395 ]
>>> #  [276.32355  276.61893  277.20465  277.50006 ]]

将此缩放后的图像转换回uint8将导致溢出。

>>> image4x4_casted = ops.cast(image4x4, "uint8")
>>> print(image4x4_casted.numpy().squeeze())
>>> # [[234 234 235 235]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [ 20  20  21  21]]

饱和转换到uint8将在转换前将值裁剪到uint8范围,并且不会导致溢出。

>>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
>>> print(image4x4_saturate_casted.numpy().squeeze())
>>> # [[  0   0   0   0]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [255 255 255 255]]

[源代码]

scan 函数

keras.ops.scan(f, init, xs=None, length=None, reverse=False, unroll=1)

在保留状态的同时,沿着前导数组轴扫描函数。

xs的类型是数组类型或None,并且ys的类型是数组类型时,scan()的语义由以下Python实现大致给出

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

循环携带的值carryinit)必须在所有迭代中保持固定的形状和数据类型。

在TensorFlow中,y的形状和数据类型必须与carry匹配。在其他后端不需要。

参数

  • f:Callable 定义每次循环迭代的逻辑。它接受两个参数,第一个是循环携带的值,第二个是沿其前导轴的xs的切片。此可调用对象返回一对,其中第一个表示循环携带的新值,第二个表示输出的切片。
  • init:初始循环携带值。这可以是标量、张量或任何嵌套结构。它必须与f返回的第一个元素的结构匹配。
  • xs:可选值,沿着其前导轴扫描。这可以是张量或任何嵌套结构。如果未提供xs,则必须指定length来定义循环迭代次数。默认为None
  • length:可选整数,指定循环迭代次数。如果未提供length,则默认为xs中数组的前导轴的大小。默认为None
  • reverse:可选布尔值,指定是向前运行扫描迭代还是反向运行,相当于反转xsys中数组的前导轴。
  • unroll:可选正整数或布尔值,指定在循环的单个迭代中展开多少个扫描迭代。如果提供整数,则它确定在循环的单个卷绕迭代中运行多少个展开的循环迭代。如果提供布尔值,它将确定循环是否完全展开(unroll=True)或完全不展开(unroll=False)。请注意,展开仅受JAX和TensorFlow后端支持。

返回值

一对,其中第一个元素表示最终的循环携带值,第二个元素表示在输入的前导轴上扫描时f的堆叠输出。

示例

>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = keras.ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
[1, 3, 6, 10, 15]

[源代码]

scatter 函数

keras.ops.scatter(indices, values, shape)

返回形状为shape的张量,其中indices设置为values

从高级别来看,此操作执行zeros[indices] = updates并返回输出。它等价于

zeros = keras.ops.zeros(shape)
output = keras.ops.scatter_update(zeros, indices, values)

参数

  • indices:指定values中值的索引的张量或列表/元组。
  • values:一个张量,要在indices处设置的值。
  • shape:输出张量的形状。

示例

>>> indices = [[0, 1], [1, 1]]
>>> values = np.array([1., 1.])
>>> keras.ops.scatter(indices, values, shape=(2, 2))
array([[0., 1.],
       [0., 1.]])

[源代码]

scatter_update 函数

keras.ops.scatter_update(inputs, indices, updates)

通过散列(稀疏)索引处的更新来更新输入。

从高级别来看,此操作执行inputs[indices] = updates。假设inputs是一个形状为(D0, D1, ..., Dn)的张量,scatter_update有两种主要用法。

  1. indices是一个形状为(num_updates, n)的二维张量,其中num_updates是要执行的更新次数,而updates是一个形状为(num_updates,)的一维张量。例如,如果inputszeros((4, 4, 4)),并且我们想要将inputs[1, 2, 3]inputs[0, 1, 3]更新为1,那么我们可以使用
inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]]
updates = np.array([1., 1.])
inputs = keras.ops.scatter_update(inputs, indices, updates)

2 indices是一个形状为(num_updates, k)的二维张量,其中num_updates是要执行的更新次数,而kk < n)是indices中每个索引的大小。updates是一个n - k维张量,形状为(num_updates, inputs.shape[k:])。例如,如果inputs = np.zeros((4, 4, 4)),并且我们想要将inputs[1, 2, :]inputs[2, 3, :]更新为[1, 1, 1, 1],那么indices的形状将为(num_updates, 2)k = 2),而updates的形状将为(num_updates, 4)inputs.shape[2:] = 4)。请参阅下面的代码

inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]]
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
inputs = keras.ops.scatter_update(inputs, indices, updates)

参数

  • inputs:一个张量,要更新的张量。
  • indices:形状为(N, inputs.ndim)的张量或列表/元组,指定要更新的索引。N是要更新的索引数,必须等于updates的第一维。
  • updates:一个张量,要放在inputsindices处的新值。

返回值

一个张量,与inputs具有相同的形状和数据类型。


[源代码]

segment_max 函数

keras.ops.segment_max(data, segment_ids, num_segments=None, sorted=False)

计算张量中段的最大值。

参数

  • data:输入张量。
  • segment_ids:一个N维张量,包含data中每个元素的段索引。data.shape[:len(segment_ids.shape)] 应该匹配。
  • num_segments:表示段总数的整数。如果未指定,则从segment_ids中的最大值推断得出。
  • sorted:一个布尔值,指示segment_ids是否已排序。默认为False

返回值

一个包含段最大值的张量,其中每个元素表示data中相应段的最大值。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_max(data, segment_ids, num_segments)
array([2, 20, 200], dtype=int32)

[源代码]

segment_sum 函数

keras.ops.segment_sum(data, segment_ids, num_segments=None, sorted=False)

计算张量中段的总和。

参数

  • data:输入张量。
  • segment_ids:一个N维张量,包含data中每个元素的段索引。段ID的Num dims应该严格小于或等于data中的dims数量。
  • num_segments:表示段总数的整数。如果未指定,则从segment_ids中的最大值推断得出。
  • sorted:一个布尔值,指示segment_ids是否已排序。默认为False

返回值

一个包含段总和的张量,其中每个元素表示data中相应段的总和。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_sum(data, segment_ids,num_segments)
array([3, 30, 300], dtype=int32)

[源代码]

shape 函数

keras.ops.shape(x)

获取张量输入的形状。

注意:在TensorFlow后端,当x是一个形状动态的tf.Tensor时,在编译函数的上下文中动态的维度将具有tf.Tensor值而不是静态整数值。

参数

  • x:一个张量。此函数将尝试访问输入张量的shape属性。

返回值

一个整数或None值的元组,指示输入张量的形状。

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.shape(x)
(8, 12)

[源代码]

slice 函数

keras.ops.slice(inputs, start_indices, shape)

返回输入张量的切片。

从高级别来看,此操作是数组切片的显式替换,例如inputs[start_indices: start_indices + shape]。与通过方括号进行切片不同,此操作将在所有后端接受张量起始索引,这在通过其他张量操作动态计算索引时非常有用。

inputs = np.zeros((5, 5))
start_indices = np.array([3, 3])
shape = np.array([2, 2])
inputs = keras.ops.slice(inputs, start_indices, shape)

参数

  • inputs:一个张量,要更新的张量。
  • start_indices:形状为(inputs.ndim,)的列表/元组,指定更新的起始索引。
  • shape:返回切片的完整形状。

返回值

一个张量,与inputs具有相同的形状和数据类型。


[源代码]

slice_update 函数

keras.ops.slice_update(inputs, start_indices, updates)

通过在更新值的张量中切片来更新输入。

从高级别来看,此操作执行inputs[start_indices: start_indices + updates.shape] = updates。假设inputs是一个形状为(D0, D1, ..., Dn)的张量,start_indices必须是n个整数的列表/元组,指定起始索引。updates必须与inputs具有相同的秩,并且每个维度的尺寸不得超过Di - start_indices[i]。例如,如果我们有二维输入inputs = np.zeros((5, 5)),并且我们想要将最后2行和最后2列的交集更新为1,即inputs[3:, 3:] = np.ones((2, 2)),那么我们可以使用下面的代码

inputs = np.zeros((5, 5))
start_indices = [3, 3]
updates = np.ones((2, 2))
inputs = keras.ops.slice_update(inputs, start_indices, updates)

参数

  • inputs:一个张量,要更新的张量。
  • start_indices:形状为(inputs.ndim,)的列表/元组,指定更新的起始索引。
  • updates:一个张量,要放在inputsindices处的新值。updates必须与inputs具有相同的秩。

返回值

一个张量,与inputs具有相同的形状和数据类型。


[源代码]

stop_gradient 函数

keras.ops.stop_gradient(variable)

停止梯度计算。

参数

  • variable:要禁用梯度计算的张量变量。

返回值

梯度计算被禁用的变量。

示例

>>> var = keras.backend.convert_to_tensor(
...     [1., 2., 3.],
...     dtype="float32"
... )
>>> var = keras.ops.stop_gradient(var)

[源代码]

switch 函数

keras.ops.switch(index, branches, *operands)

应用由index给出的branches中的一个。

如果index超出范围,则将其钳位到范围内。

switch的语义由以下Python实现大致给出

def switch(index, branches, *operands):
    index = clamp(0, index, len(branches) - 1)
    return branches[index](*operands)

参数

  • index:一个整数标量,指示要应用哪个分支函数。
  • branches:基于index要应用的一系列函数。
  • operands:应用于任何分支的输入。

返回值

根据index选择的分支的branch(*operands)的输出。

示例

>>> add_fn = lambda x, y: x + y
>>> subtract_fn = lambda x, y: x - y
>>> x = keras.ops.array(2.0)
>>> y = keras.ops.array(0.5)
>>> branches = [add_fn, subtract_fn]
>>> keras.ops.switch(0, branches, x, y)
2.5
>>> keras.ops.switch(1, branches, x, y)
1.5

[源代码]

top_k 函数

keras.ops.top_k(x, k, sorted=True)

查找张量中的前k个值及其索引。

参数

  • x:输入张量。
  • k:表示要检索的前k个元素数量的整数。
  • sorted:一个布尔值,指示是否按降序对输出进行排序。默认为True

返回值

包含两个张量的元组。第一个张量包含前k个值,第二个张量包含输入张量中前k个值的索引。

示例

>>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
>>> values, indices = top_k(x, k=3)
>>> print(values)
array([9 7 5], shape=(3,), dtype=int32)
>>> print(indices)
array([4 2 0], shape=(3,), dtype=int32)

[源代码]

unstack 函数

keras.ops.unstack(x, num=None, axis=0)

将秩R张量的给定维度解包成秩(R-1)张量。

参数

  • x:输入张量。
  • num:维度轴的长度。如果为None,则自动推断。
  • axis:要解包的轴。

返回值

沿着给定轴解包的张量列表。

示例

>>> x = keras.ops.array([[1, 2], [3, 4]])
>>> keras.ops.unstack(x, axis=0)
[array([1, 2]), array([3, 4])]

[源代码]

vectorized_map 函数

keras.ops.vectorized_map(function, elements)

张量(s) elements的轴0上function的并行映射。

从示意图上看,vectorized_map实现了以下内容,在单个张量输入elements的情况下

def vectorized_map(function, elements)
    outputs = []
    for e in elements:
        outputs.append(function(e))
    return stack(outputs)

在张量elements的可迭代情况下,它实现了以下内容

def vectorized_map(function, elements)
    batch_size = elements[0].shape[0]
    outputs = []
    for index in range(batch_size):
        outputs.append(function([e[index] for e in elements]))
    return np.stack(outputs)

在这种情况下,function预期将单个张量参数列表作为输入。


[源代码]

while_loop 函数

keras.ops.while_loop(cond, body, loop_vars, maximum_iterations=None)

While循环实现。

参数

  • cond:表示循环终止条件的可调用对象。必须接受类似loop_vars的结构作为参数。如果loop_vars是元组或列表,则loop_vars的每个元素都将按位置传递给可调用对象。
  • body:表示循环体的主体。必须接受类似loop_vars的结构作为参数,并返回具有相同结构的更新值。如果loop_vars是元组或列表,则loop_vars的每个元素都将按位置传递给可调用对象。
  • loop_vars:张量状态的任意嵌套结构,用于在循环迭代中持久化。
  • maximum_iterations:While循环要运行的最大迭代次数(可选)。如果提供,则cond输出将与确保执行的迭代次数不大于maximum_iterations的其他条件进行AND运算。

返回值

张量列表/元组,与inputs具有相同的形状和数据类型。

示例

>>> i = 0
>>> cond = lambda i: i < 10
>>> body = lambda i: i + 1
>>> keras.ops.while_loop(cond, body, i)
10
>>> x, y = 0, 1
>>> cond = lambda x, y: x < 10
>>> body = lambda x, y: (x + 1, y + 1)
>>> keras.ops.while_loop(cond, body, (x, y))
10, 11