Merged (changes from 7 commits)
src/cvxpylayers/interfaces/diffcp_if.py (240 changes: 166 additions, 74 deletions)
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
 import diffcp
 import numpy as np
@@ -16,6 +17,85 @@
 except ImportError:
     torch = None  # type: ignore[assignment]

+if TYPE_CHECKING:
+    # Type alias for multi-framework tensor types
+    TensorLike = torch.Tensor | jnp.ndarray | np.ndarray
+else:
+    TensorLike = Any

[Review comment from a collaborator]: Smart. I like it.
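Note: the runtime branch matters because torch and jax are optional imports that may be None (see the except ImportError above), so evaluating the union at runtime would raise. A minimal sketch of the same idiom, standalone and independent of this repo:

    from typing import TYPE_CHECKING, Any

    try:
        import torch
    except ImportError:
        torch = None  # optional dependency, may be absent at runtime

    if TYPE_CHECKING:
        TensorLike = torch.Tensor  # evaluated only by static type checkers
    else:
        TensorLike = Any  # runtime fallback; never touches the optional module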


+def _detect_batch_size(con_values: TensorLike) -> tuple[int, bool]:
+    """Detect batch size and whether input was originally unbatched.
+
+    Handles both PyTorch tensors and JAX arrays by checking the number
+    of dimensions.
+
+    Args:
+        con_values: Constraint values (torch.Tensor or jnp.ndarray)
+
+    Returns:
+        Tuple of (batch_size, originally_unbatched) where:
+        - batch_size: Number of batch elements (1 if unbatched)
+        - originally_unbatched: True if input had no batch dimension
+    """
+    # Handle both torch tensors (.dim()) and jax/numpy arrays (.ndim)
+    ndim = con_values.dim() if hasattr(con_values, "dim") else con_values.ndim  # type: ignore[attr-defined]
+
+    if ndim == 1:
+        return 1, True  # Unbatched input
+    else:
+        return con_values.shape[1], False  # Batched input
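Note: a quick sanity check of the helper's contract. Plain numpy stands in for torch/jax here; the (num_entries, batch) layout with the batch on axis 1 is per the docstring:

    import numpy as np

    _detect_batch_size(np.zeros(5))       # -> (1, True): 1-D input, no batch axis
    _detect_batch_size(np.zeros((5, 3)))  # -> (3, False): batch axis is axis 1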


+def _build_diffcp_matrices(
+    con_values: TensorLike,
+    lin_obj_values: TensorLike,
+    A_structure: tuple[np.ndarray, np.ndarray],
+    A_shape: tuple[int, int],
+    b_idx: np.ndarray,
+    batch_size: int,
+) -> tuple[list[sp.csc_matrix], list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
+    """Build DIFFCP matrices from constraint and objective values.
+
+    Converts parameter values into the conic form required by the DIFFCP solver:
+        minimize    c^T x
+        subject to  Ax + s = b,  s in K
+    where K is a product of cones.
+
+    Args:
+        con_values: Constraint coefficient values (batched)
+        lin_obj_values: Linear objective coefficient values (batched)
+        A_structure: Sparse matrix structure (indices, indptr)
+        A_shape: Shape of augmented constraint matrix
+        b_idx: Indices for extracting RHS from last column
+        batch_size: Number of batch elements
+
+    Returns:
+        Tuple of (As, bs, cs, b_idxs) where:
+        - As: List of constraint matrices (one per batch element)
+        - bs: List of RHS vectors (one per batch element)
+        - cs: List of cost vectors (one per batch element)
+        - b_idxs: List of RHS index arrays (one per batch element)
+    """
+    As, bs, cs, b_idxs = [], [], [], []
+
+    for i in range(batch_size):
+        # Convert to numpy; handles both torch tensors and jax arrays
+        con_vals_i = np.array(con_values[:, i])
+        lin_vals_i = np.array(lin_obj_values[:-1, i])
+
+        # Build augmented matrix [A | b] from the sparse structure
+        A_aug = sp.csc_matrix(
+            (con_vals_i, *A_structure),
+            shape=A_shape,
+        )
+        # Extract A and b, negating A to match the DIFFCP convention
+        As.append(-A_aug[:, :-1])
+        bs.append(A_aug[:, -1].toarray().flatten())
+        cs.append(lin_vals_i)
+        b_idxs.append(b_idx)
+
+    return As, bs, cs, b_idxs
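Note: the [A | b] layout means a single sparse matrix carries both the constraint matrix and the right-hand side. The split below mirrors the loop body on a toy 2x2 system (illustrative values, not repo data):

    import numpy as np
    import scipy.sparse as sp

    A_aug = sp.csc_matrix(np.array([[1.0, 0.0, 7.0],
                                    [0.0, 2.0, 8.0]]))  # last column holds b
    A = -A_aug[:, :-1]                    # negate to match DIFFCP's Ax + s = b
    b = A_aug[:, -1].toarray().flatten()  # array([7., 8.])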


 class DIFFCP_ctx:
     c_slice: slice
@@ -50,28 +130,22 @@ def __init__(
         self.dims = dims
 
     def torch_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_data":
-        # Detect batch size and whether input was originally unbatched
-        if con_values.dim() == 1:
-            originally_unbatched = True
-            batch_size = 1
-            # Add batch dimension for uniform handling
+        batch_size, originally_unbatched = _detect_batch_size(con_values)
+
+        # Add batch dimension for uniform handling if needed
+        if originally_unbatched:
             con_values = con_values.unsqueeze(1)
             lin_obj_values = lin_obj_values.unsqueeze(1)
-        else:
-            originally_unbatched = False
-            batch_size = con_values.shape[1]
 
-        # Build lists for all batch elements
-        As, bs, cs, b_idxs = [], [], [], []
-        for i in range(batch_size):
-            A_aug = sp.csc_matrix(
-                (con_values[:, i].cpu().numpy(), *self.A_structure),
-                shape=self.A_shape,
-            )
-            As.append(-A_aug[:, :-1])  # Negate A to match DIFFCP convention
-            bs.append(A_aug[:, -1].toarray().flatten())
-            cs.append(lin_obj_values[:-1, i].cpu().numpy())
-            b_idxs.append(self.b_idx)
+        # Build matrices
+        As, bs, cs, b_idxs = _build_diffcp_matrices(
+            con_values,
+            lin_obj_values,
+            self.A_structure,
+            self.A_shape,
+            self.b_idx,
+            batch_size,
+        )

         return DIFFCP_data(
             As=As,
@@ -89,28 +163,23 @@ def jax_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_data":
"JAX interface requires 'jax' package to be installed. "
"Install with: pip install jax"
)
# Detect batch size and whether input was originally unbatched
if con_values.ndim == 1:
originally_unbatched = True
batch_size = 1
# Add batch dimension for uniform handling

batch_size, originally_unbatched = _detect_batch_size(con_values)

# Add batch dimension for uniform handling if needed
if originally_unbatched:
con_values = jnp.expand_dims(con_values, 1)
lin_obj_values = jnp.expand_dims(lin_obj_values, 1)
else:
originally_unbatched = False
batch_size = con_values.shape[1]

# Build lists for all batch elements
As, bs, cs, b_idxs = [], [], [], []
for i in range(batch_size):
A_aug = sp.csc_matrix(
(np.array(con_values[:, i]), *self.A_structure),
shape=self.A_shape,
)
As.append(-A_aug[:, :-1]) # Negate A to match DIFFCP convention
bs.append(A_aug[:, -1].toarray().flatten())
cs.append(np.array(lin_obj_values[:-1, i]))
b_idxs.append(self.b_idx)

# Build matrices
As, bs, cs, b_idxs = _build_diffcp_matrices(
con_values,
lin_obj_values,
self.A_structure,
self.A_shape,
self.b_idx,
batch_size,
)

         return DIFFCP_data(
             As=As,
@@ -123,6 +192,55 @@ def jax_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_data":
         )


+def _compute_gradients(
+    adj_batch: Callable,
+    dprimal: TensorLike,
+    ddual: TensorLike,
+    bs: list[np.ndarray],
+    b_idxs: list[np.ndarray],
+    batch_size: int,
+) -> tuple[list[np.ndarray], list[np.ndarray]]:
+    """Compute gradients using DIFFCP's adjoint method.
+
+    Uses implicit differentiation to compute gradients of the optimization
+    solution with respect to problem parameters. The adjoint method computes
+    these gradients efficiently by solving the adjoint system.
+
+    Args:
+        adj_batch: DIFFCP's batch adjoint function
+        dprimal: Incoming gradients w.r.t. primal solution
+        ddual: Incoming gradients w.r.t. dual solution
+        bs: List of RHS vectors from forward pass
+        b_idxs: List of RHS indices from forward pass
+        batch_size: Number of batch elements
+
+    Returns:
+        Tuple of (dq_batch, dA_batch) where:
+        - dq_batch: List of gradients w.r.t. linear objective coefficients
+        - dA_batch: List of gradients w.r.t. constraint coefficients
+    """
+    # Convert incoming gradients to lists for DIFFCP
+    dxs = [np.array(dprimal[i]) for i in range(batch_size)]
+    dys = [np.array(ddual[i]) for i in range(batch_size)]
+    dss = [np.zeros_like(bs[i]) for i in range(batch_size)]  # No gradient w.r.t. slack
+
+    # Call DIFFCP's batch adjoint to get gradients w.r.t. problem data
+    dAs, dbs, dcs = adj_batch(dxs, dys, dss)
+
+    # Aggregate gradients from each batch element
+    dq_batch = []
+    dA_batch = []
+    for i in range(batch_size):
+        # Negate dA because A was negated in the forward pass; db is left
+        # as-is because b was never negated
+        con_grad = np.hstack([-dAs[i].data, dbs[i][b_idxs[i]]])
+        # Add a zero gradient for the constant offset term
+        lin_grad = np.hstack([dcs[i], np.array([0.0])])
+        dA_batch.append(con_grad)
+        dq_batch.append(lin_grad)
+
+    return dq_batch, dA_batch
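Note: a sketch of where adj_batch comes from upstream, assuming diffcp's solve_and_derivative_batch interface; the tiny LP below (minimize x subject to x >= 1, written conically as -x + s = -1, s >= 0) is illustrative, not repo data:

    import diffcp
    import numpy as np
    import scipy.sparse as sp

    A = sp.csc_matrix(np.array([[-1.0]]))
    b = np.array([-1.0])
    c = np.array([1.0])
    cone = {"l": 1}  # one nonnegative-orthant cone

    xs, ys, ss, _, adj_batch = diffcp.solve_and_derivative_batch(
        [A, A], [b, b], [c, c], [cone, cone]
    )
    # Seed the adjoint with upstream gradients; slack gradients are zero,
    # matching the dss construction in _compute_gradients.
    dAs, dbs, dcs = adj_batch(
        [np.ones(1)] * 2, [np.zeros(1)] * 2, [np.zeros(1)] * 2
    )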


 @dataclass
 class DIFFCP_data:
     As: list[sp.csc_matrix]
@@ -161,23 +279,10 @@ def torch_derivative(self, primal, dual, adj_batch):
"PyTorch interface requires 'torch' package. Install with: pip install torch"
)

# Split batched tensors into lists
dxs = [primal[i].numpy() for i in range(self.batch_size)]
dys = [dual[i].numpy() for i in range(self.batch_size)]
dss = [np.zeros_like(self.bs[i]) for i in range(self.batch_size)]

# Call batch adjoint
dAs, dbs, dcs = adj_batch(dxs, dys, dss)

# Aggregate gradients from each batch element
dq_batch = []
dA_batch = []
for i in range(self.batch_size):
# Negate dA because A was negated in forward pass, but not db (b was not negated)
con_grad = np.hstack([-dAs[i].data, dbs[i][self.b_idxs[i]]])
lin_grad = np.hstack([dcs[i], np.array([0.0])])
dA_batch.append(con_grad)
dq_batch.append(lin_grad)
# Compute gradients
dq_batch, dA_batch = _compute_gradients(
adj_batch, primal, dual, self.bs, self.b_idxs, self.batch_size
)

# Stack into shape (num_entries, batch_size)
dq_stacked = torch.stack([torch.from_numpy(g) for g in dq_batch]).T
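Note: the transpose is what produces the (num_entries, batch_size) layout, since torch.stack puts the batch on axis 0. A toy check with illustrative sizes:

    import torch

    grads = [torch.zeros(7) for _ in range(3)]  # 3 batch elements, 7 entries each
    torch.stack(grads).shape                    # torch.Size([3, 7])
    torch.stack(grads).T.shape                  # torch.Size([7, 3]) = (num_entries, batch_size)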
@@ -216,23 +321,10 @@ def jax_solve(self, solver_args=None):
         return primal, dual, adj_batch
 
     def jax_derivative(self, dprimal, ddual, adj_batch):
-        # Split batched arrays into lists
-        dxs = [np.array(dprimal[i]) for i in range(self.batch_size)]
-        dys = [np.array(ddual[i]) for i in range(self.batch_size)]
-        dss = [np.zeros_like(self.bs[i]) for i in range(self.batch_size)]
-
-        # Call batch adjoint
-        dAs, dbs, dcs = adj_batch(dxs, dys, dss)
-
-        # Aggregate gradients from each batch element
-        dq_batch = []
-        dA_batch = []
-        for i in range(self.batch_size):
-            # Negate dA because A was negated in forward pass, but not db (b was not negated)
-            con_grad = np.hstack([-dAs[i].data, dbs[i][self.b_idxs[i]]])
-            lin_grad = np.hstack([dcs[i], np.array([0.0])])
-            dA_batch.append(con_grad)
-            dq_batch.append(lin_grad)
+        # Compute gradients
+        dq_batch, dA_batch = _compute_gradients(
+            adj_batch, dprimal, ddual, self.bs, self.b_idxs, self.batch_size
+        )
 
         # Stack into shape (num_entries, batch_size)
        dq_stacked = jnp.stack([jnp.array(g) for g in dq_batch]).T