In TensorFlow, conditional logic that depends on tensor values (the equivalent of an if/else statement) is typically implemented with the tf.cond function. tf.cond is a control flow operation that takes a scalar boolean tensor (the predicate) and two functions. Depending on whether the predicate is True or False, it runs one of the two functions and returns that function's result.
Basic syntax of tf.cond:
```python
tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    name=None
)
```
- `pred`: A scalar boolean tensor that determines which branch is taken.
- `true_fn`: The function (a callable with no arguments) to execute when `pred` is `True`.
- `false_fn`: The function (a callable with no arguments) to execute when `pred` is `False`.
- `name`: (Optional) A name for the operation.
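Because each branch is an argument-free callable, a branch can return more than one tensor, as long as both branches return the same nested structure (for example, a tuple of the same length) with matching dtypes. The following is a minimal sketch assuming TensorFlow 2.x with eager execution; the helper `scale_and_flag` is hypothetical and returns a `(value, flag)` pair from each branch.

```python
import tensorflow as tf

def scale_and_flag(x):
    """Hypothetical helper: doubles positive inputs, halves the rest,
    and also reports which branch ran."""
    def true_fn():
        # Both branches return a 2-tuple with dtypes (float32, bool).
        return x * 2.0, tf.constant(True)

    def false_fn():
        return x / 2.0, tf.constant(False)

    return tf.cond(x > 0, true_fn=true_fn, false_fn=false_fn, name="scale_and_flag")

value, took_true_branch = scale_and_flag(tf.constant(-4.0))
print(value.numpy(), took_true_branch.numpy())  # -2.0 False
```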
Example
Suppose we have a simple task where we determine the operation based on the value of an input tensor: if the value is greater than 0, we multiply it by 2; otherwise, we divide it by 2.
Here is how to implement this logic using tf.cond:
```python
import tensorflow as tf

# Define the input tensor
x = tf.constant(3.0)

# Define the conditional operation: double x if it is positive, otherwise halve it
result = tf.cond(
    tf.greater(x, 0),
    true_fn=lambda: x * 2,
    false_fn=lambda: x / 2,
)

# TensorFlow 2.x runs eagerly by default, so the value is available directly.
# No Session is needed (calling sess.run() on an eager tensor raises an error).
print("Result:", result.numpy())
```
In this example, tf.greater checks whether x is greater than 0. Since x is 3.0, tf.greater(x, 0) evaluates to True, so true_fn (lambda: x * 2) is executed and the printed result is 6.0.
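The same logic can also run as a graph by wrapping it in `tf.function`. In that case both lambdas are called once while the function is traced, to build the two branches of a graph conditional, and only the selected branch executes each time the graph runs. A minimal sketch, assuming TensorFlow 2.x; the function name `scale` is illustrative:

```python
import tensorflow as tf

@tf.function
def scale(x):
    # Traced once per input signature; the branch is chosen at run time
    # from the value of x, not from Python control flow.
    return tf.cond(x > 0, true_fn=lambda: x * 2, false_fn=lambda: x / 2)

print("Result:", scale(tf.constant(3.0)).numpy())   # Result: 6.0
print("Result:", scale(tf.constant(-3.0)).numpy())  # Result: -1.5
```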
This approach is useful when building models whose behavior depends on values known only at run time, for example switching behavior between training iterations or applying different computations to different subsets of the data.
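As one illustration of behavior that changes across iterations, the sketch below uses `tf.cond` to implement a learning-rate warmup: the rate ramps up linearly for the first few steps and then stays constant. The function `learning_rate` and its `base_lr` and `warmup_steps` values are hypothetical choices for this example, assuming TensorFlow 2.x.

```python
import tensorflow as tf

@tf.function
def learning_rate(step, base_lr=0.1, warmup_steps=4):
    # Hypothetical schedule: linear warmup, then a constant rate.
    step = tf.cast(step, tf.float32)
    warmup = tf.cast(warmup_steps, tf.float32)
    return tf.cond(
        step < warmup,
        # During warmup: base_lr * (step + 1) / warmup_steps
        true_fn=lambda: base_lr * (step + 1.0) / warmup,
        # After warmup: the constant base rate (float32, like the other branch)
        false_fn=lambda: tf.constant(base_lr),
    )

for step in range(6):
    print(step, float(learning_rate(tf.constant(step))))
```

Because the step is passed as a tensor, the function is traced once and the branch is selected on every call from the current step value; passing a plain Python int instead would cause `tf.function` to retrace for each new value.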