Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ci/run_bazel_test_tpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ else
//tests/pallas:tpu_pallas_test_tpu \
//tests/pallas:tpu_pallas_call_print_test_tpu \
//tests/pallas:indexing_test_tpu \
//tests/pallas:pallas_cost_estimate_test_tpu \
//tests/pallas:pallas_error_handling_test_tpu \
//tests/pallas:pallas_jumble_test_tpu \
//tests/pallas:pallas_shape_poly_test_tpu \
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/pallas/cost_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,12 @@ def dot_general_cost_rule(ctx: Context,
assert len(lhs_batch_dims) == len(rhs_batch_dims)
flops = 1
# Flops along a contracting dim is 2*dim (addition and multiplication)
contracting_flops = 1
for i in range(len(lhs_contracting_dims)):
lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i]
assert x_shape[lhs_dim] == y_shape[rhs_dim]
flops *= 2 * x_shape[lhs_dim]
contracting_flops *= x_shape[lhs_dim]
flops *= 2 * contracting_flops
# Now we handle all other dimensions.
for i, lhs_dim in enumerate(x_shape):
if i in lhs_contracting_dims:
Expand Down
8 changes: 3 additions & 5 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,15 @@ jax_multiplatform_test(
]),
)

jax_multiplatform_test(
jax_py_test(
name = "pallas_cost_estimate_test",
srcs = [
"pallas_cost_estimate_test.py",
],
args = ["--jax_test_dut=cpu"],
deps = [
"//jax:pallas",
"//jax:pallas_gpu",
"//jax:pallas_gpu_ops",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src:test_util",
] + py_deps([
"absl/testing",
"numpy",
Expand Down
17 changes: 17 additions & 0 deletions tests/pallas/pallas_cost_estimate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ def matmul(a, b):
self.assertEqual(cost.transcendentals, 0)
self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n))

@parameterized.parameters(
((10, 11, 12), (11, 12), "abc,bc->a"),
((10, 11, 12), (13, 11, 12), "abc,dbc->ad"),
((10, 11, 12), (9, 10, 11, 12), "abc,dabc->d"),
)
def test_einsum(self, a_shape, b_shape, pattern):
a = jnp.ones(a_shape, dtype=jnp.float32)
b = jnp.ones(b_shape, dtype=jnp.float32)
def matmul(a, b):
return jnp.einsum(pattern, a, b)
cost = cost_estimate.estimate_cost(
matmul,
jax.ShapeDtypeStruct(a_shape, jnp.float32),
jax.ShapeDtypeStruct(b_shape, jnp.float32))
xla_flops = jax.jit(matmul).lower(a, b).compile().cost_analysis()['flops']
self.assertEqual(cost.flops, int(xla_flops))

def test_attention(self):
qk_dim = 16
v_dim = 4
Expand Down
Loading