服务端阅读 06月4日 12:33
TensorFlow 张量怎么创建和操作?constant 和 Variable 有什么区别?
张量就是 TensorFlow 里的多维数组——0 维是标量(一个数),1 维是向量,2 维是矩阵,3 维及以上就是高阶张量。它和 NumPy 的 ndarray 很像,但有两个关键区别:张量可以放在 GPU 上加速计算,张量是计算图中的节点,可以被自动求导。两种张量:constant 和 Variabletf.constant 创建不可变张量,值一旦设定不能改。tf.Variable 创建可变张量,可以通过 .assign()、.assign_add() 修改。模型权重用 Variable,输入数据用 constant——这是最核心的区分:训练过程中需要更新的参数必须是 Variable,否则梯度无法回传。# constant:不可变,用于输入/超参x = tf.constant([[1.0, 2.0], [3.0, 4.0]])# Variable:可变,用于模型权重w = tf.Variable(tf.random.normal([2, 2]))w.assign_add(tf.ones([2, 2])) # 原地加1常用创建方式初始化模型权重时最常用的三种:tf.random.normal(正态分布,适合全连接层)、tf.random.uniform(均匀分布,适合某些初始化策略)、tf.zeros(偏置项常用零初始化)。从已有数据创建用 tf.constant 或 tf.convert_to_tensor(自动把 NumPy 数组/Python 列表转成张量)。w1 = tf.random.normal([3, 128], stddev=0.02) # 权重:正态分布b1 = tf.zeros([128]) # 偏置:零初始化w2 = tf.Variable(tf.random.uniform([128, 10], -0.1, 0.1))最容易踩的坑:数据类型和形状类型陷阱:tf.constant([1, 2, 3]) 默认是 int32,做除法 a / b 会出错(整数除法不是浮点除法)。养成习惯:涉及计算的张量显式指定 dtype=tf.float32。形状操作:tf.reshape 不改变数据只是重新切分维度,tf.expand_dims 加一个大小为 1 的维度(常用于 batch 维度),tf.squeeze 去掉大小为 1 的维度。记住 reshape 前后元素总数必须一致。x = tf.constant([1, 2, 3, 4, 5, 6]) # shape (6,)x = tf.reshape(x, [2, 3]) # shape (2, 3)x = tf.expand_dims(x, 0) # shape (1, 2, 3) — 加 batch 维x = tf.squeeze(x) # shape (2, 3) — 去掉 size-1 维广播机制不同形状的张量做运算时,TensorFlow 自动把小形状"广播"到大形状:(2, 3) + (3,) → 先把 (3,) 复制两行变成 (2, 3) 再相加。这和 NumPy 的广播规则完全一致。常见场景:给矩阵的每一行加一个偏置 (batch, dim) + (dim,)。追问tf.Tensor 和 NumPy ndarray 怎么互转?tensor.numpy() 把张量转成 NumPy 数组(Eager 模式下),tf.convert_to_tensor(np_array) 反向转换。注意:GPU 上的张量转 NumPy 会触发设备同步(数据从 GPU 拷回 CPU),频繁调用会拖慢速度。在训练循环里尽量避免这种转换。张量的 rank、shape、axis 怎么理解?rank 是维度数(scalar=0, vector=1, matrix=2),shape 是每个维度的大小(如 [3, 4] 表示 3 行 4 列),axis 是对哪个维度操作(axis=0 沿行方向,axis=1 沿列方向)。tf.reduce_mean(x, axis=0) 就是"对每一列求平均"——消掉第 0 维,结果从 (3,4) 变成 (4,)。TensorFlow 张量和 PyTorch 张量有什么区别?核心概念一样,API 略有不同:TF 用 tf.reshape,PyTorch 用 torch.reshape 或 .view();TF 的张量默认放在 CPU,需要 with tf.device 指定 GPU,PyTorch 用 .to('cuda')。最大的区别是 TF 张量有 tf.Variable 的概念(可变 vs 不可变),PyTorch 所有张量都可变,通过 requires_grad=True 控制是否求导。