Skip to content

Commit 4d638b7

Browse files
Merge pull request #33419 from Prakharprasun:doc-lax-conv-transpose
PiperOrigin-RevId: 839829550
2 parents 3b5c07c + 7a57a4c commit 4d638b7

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

jax/_src/lax/convolution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
302302
This function directly calculates a fractionally strided conv rather than
303303
indirectly calculating the gradient (transpose) of a forward convolution.
304304
305+
Notes:
306+
TensorFlow/Keras Compatibility: By default, JAX does NOT reverse the
307+
kernel's spatial dimensions. This differs from TensorFlow's "Conv2DTranspose"
308+
and similar frameworks, which flip spatial axes and swap input/output channels.
309+
310+
To match TensorFlow/Keras behavior, set "transpose_kernel=True" .
311+
305312
Args:
306313
lhs: a rank `n+2` dimensional input array.
307314
rhs: a rank `n+2` dimensional array of kernel weights.

0 commit comments

Comments
 (0)