Skip to content

Commit 7a57a4c

Browse files
committed
docs: improve conv_transpose clarity by documenting the diff spatial axes bw tensorflow and lax.conv_transpose
1 parent c83a719 commit 7a57a4c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

jax/_src/lax/convolution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
294294
This function directly calculates a fractionally strided conv rather than
295295
indirectly calculating the gradient (transpose) of a forward convolution.
296296
297+
Notes:
298+
TensorFlow/Keras Compatibility: By default, JAX does NOT reverse the
299+
kernel's spatial dimensions. This differs from TensorFlow's "Conv2DTranspose"
300+
and similar frameworks, which flip spatial axes and swap input/output channels.
301+
302+
To match TensorFlow/Keras behavior, set "transpose_kernel=True" .
303+
297304
Args:
298305
lhs: a rank `n+2` dimensional input array.
299306
rhs: a rank `n+2` dimensional array of kernel weights.

0 commit comments

Comments
 (0)