|
47 | 47 | split_dims, |
48 | 48 | squeeze, |
49 | 49 | to_one_hot, |
| 50 | + unpack, |
50 | 51 | unravel_index, |
51 | 52 | ) |
52 | 53 | from pytensor.tensor.type import ( |
@@ -1593,3 +1594,50 @@ def test_pack_basic(self): |
1593 | 1594 | packed_shape.eval(input_dict, on_unused_input="ignore"), |
1594 | 1595 | tensor.type.shape[1:-1], |
1595 | 1596 | ) |
| 1597 | + |
| 1598 | + @pytest.mark.parametrize("axes", [None, -1, (-2, -1)]) |
| 1599 | + def test_pack_unpack_round_trip(self, axes): |
| 1600 | + rng = np.random.default_rng() |
| 1601 | + |
| 1602 | + x = pt.tensor("x", shape=(3, 5)) |
| 1603 | + y = pt.tensor("y", shape=(3, 3, 5)) |
| 1604 | + z = pt.tensor("z", shape=(1, 3, 5)) |
| 1605 | + |
| 1606 | + flat_packed, packed_shapes = pack(x, y, z, axes=axes) |
| 1607 | + new_outputs = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes) |
| 1608 | + |
| 1609 | + fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE") |
| 1610 | + |
| 1611 | + input_dict = { |
| 1612 | + var.name: rng.normal(size=var.type.shape).astype(config.floatX) |
| 1613 | + for var in [x, y, z] |
| 1614 | + } |
| 1615 | + output_vals = fn(**input_dict) |
| 1616 | + |
| 1617 | + for input_val, output_val in zip(input_dict.values(), output_vals, strict=True): |
| 1618 | + np.testing.assert_allclose(input_val, output_val) |
| 1619 | + |
| 1620 | + |
def test_make_replacements_with_pack_unpack():
    """Substituting the original inputs of a graph with the unpacked
    slices of a single flat vector should leave the computed loss
    unchanged."""
    rng = np.random.default_rng()

    variables = [
        pt.tensor("x", shape=()),
        pt.tensor("y", shape=(5,)),
        pt.tensor("z", shape=(3, 3)),
    ]
    x, y, z = variables

    loss = (x + y.sum() + z.sum()) ** 2

    flat_packed, packed_shapes = pack(x, y, z, axes=None)
    flat_input = flat_packed.type()
    replacements = unpack(flat_input, axes=None, packed_shapes=packed_shapes)

    # Rewire the loss to read from the flat vector instead of x/y/z.
    loss = pytensor.graph.graph_replace(loss, dict(zip(variables, replacements)))
    # x/y/z stay as inputs: packed_shapes may still reference their shapes.
    fn = pytensor.function([flat_input, *variables], loss, mode="FAST_COMPILE")

    values = [
        rng.normal(size=var.type.shape).astype(config.floatX) for var in variables
    ]
    flat_values = np.concatenate([v.ravel() for v in values], axis=0)
    result = fn(flat_values, *values)

    # Scalar x contributes itself; y and z contribute their element sums.
    expected = sum(v.sum() for v in values) ** 2
    assert np.allclose(result, expected)
0 commit comments