diff --git a/src/cvxpylayers/interfaces/diffcp_if.py b/src/cvxpylayers/interfaces/diffcp_if.py index 20ac6cc..e3602a3 100644 --- a/src/cvxpylayers/interfaces/diffcp_if.py +++ b/src/cvxpylayers/interfaces/diffcp_if.py @@ -1,5 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass +from typing import TYPE_CHECKING, Any import diffcp import numpy as np @@ -16,6 +17,85 @@ except ImportError: torch = None # type: ignore[assignment] +if TYPE_CHECKING: + # Type alias for multi-framework tensor types + TensorLike = torch.Tensor | jnp.ndarray | np.ndarray +else: + TensorLike = Any + + +def _detect_batch_size(con_values: TensorLike) -> tuple[int, bool]: + """Detect batch size and whether input was originally unbatched. + + Handles both PyTorch tensors and JAX arrays by checking the number + of dimensions. + + Args: + con_values: Constraint values (torch.Tensor or jnp.ndarray) + + Returns: + Tuple of (batch_size, originally_unbatched) where: + - batch_size: Number of batch elements (1 if unbatched) + - originally_unbatched: True if input had no batch dimension + """ + # Handle both torch tensors (.dim()) and jax/numpy arrays (.ndim) + ndim = con_values.dim() if hasattr(con_values, "dim") else con_values.ndim # type: ignore[attr-defined] + + if ndim == 1: + return 1, True # Unbatched input + else: + return con_values.shape[1], False # Batched input + + +def _build_diffcp_matrices( + con_values: TensorLike, + lin_obj_values: TensorLike, + A_structure: tuple[np.ndarray, np.ndarray], + A_shape: tuple[int, int], + b_idx: np.ndarray, + batch_size: int, +) -> tuple[list[sp.csc_matrix], list[np.ndarray], list[np.ndarray], list[np.ndarray]]: + """Build DIFFCP matrices from constraint and objective values. + + Converts parameter values into the conic form required by DIFFCP solver: + minimize c^T x subject to Ax + s = b, s in K + where K is a product of cones. 
+ + Args: + con_values: Constraint coefficient values (batched) + lin_obj_values: Linear objective coefficient values (batched) + A_structure: Sparse matrix structure (indices, indptr) + A_shape: Shape of augmented constraint matrix + b_idx: Indices for extracting RHS from last column + batch_size: Number of batch elements + + Returns: + Tuple of (As, bs, cs, b_idxs) where: + - As: List of constraint matrices (one per batch element) + - bs: List of RHS vectors (one per batch element) + - cs: List of cost vectors (one per batch element) + - b_idxs: List of RHS index arrays (one per batch element) + """ + As, bs, cs, b_idxs = [], [], [], [] + + for i in range(batch_size): + # Convert to numpy - handles both torch tensors and jax arrays + con_vals_i = np.array(con_values[:, i]) + lin_vals_i = np.array(lin_obj_values[:-1, i]) + + # Build augmented matrix [A | b] from sparse structure + A_aug = sp.csc_matrix( + (con_vals_i, *A_structure), + shape=A_shape, + ) + # Extract A and b, negating A to match DIFFCP convention + As.append(-A_aug[:, :-1]) + bs.append(A_aug[:, -1].toarray().flatten()) + cs.append(lin_vals_i) + b_idxs.append(b_idx) + + return As, bs, cs, b_idxs + class DIFFCP_ctx: c_slice: slice @@ -50,28 +130,22 @@ def __init__( self.dims = dims def torch_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_data": - # Detect batch size and whether input was originally unbatched - if con_values.dim() == 1: - originally_unbatched = True - batch_size = 1 - # Add batch dimension for uniform handling + batch_size, originally_unbatched = _detect_batch_size(con_values) + + # Add batch dimension for uniform handling if needed + if originally_unbatched: con_values = con_values.unsqueeze(1) lin_obj_values = lin_obj_values.unsqueeze(1) - else: - originally_unbatched = False - batch_size = con_values.shape[1] - - # Build lists for all batch elements - As, bs, cs, b_idxs = [], [], [], [] - for i in range(batch_size): - A_aug = sp.csc_matrix( - (con_values[:, i].cpu().numpy(), *self.A_structure), - shape=self.A_shape, - ) - As.append(-A_aug[:, :-1]) # Negate A to match DIFFCP convention - bs.append(A_aug[:, -1].toarray().flatten()) - cs.append(lin_obj_values[:-1, i].cpu().numpy()) - b_idxs.append(self.b_idx) + + # Build matrices + As, bs, cs, b_idxs = _build_diffcp_matrices( + con_values, + lin_obj_values, + self.A_structure, + self.A_shape, + self.b_idx, + batch_size, + ) return DIFFCP_data( As=As, @@ -89,28 +163,23 @@ def jax_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_da "JAX interface requires 'jax' package to be installed. 
" "Install with: pip install jax" ) - # Detect batch size and whether input was originally unbatched - if con_values.ndim == 1: - originally_unbatched = True - batch_size = 1 - # Add batch dimension for uniform handling + + batch_size, originally_unbatched = _detect_batch_size(con_values) + + # Add batch dimension for uniform handling if needed + if originally_unbatched: con_values = jnp.expand_dims(con_values, 1) lin_obj_values = jnp.expand_dims(lin_obj_values, 1) - else: - originally_unbatched = False - batch_size = con_values.shape[1] - - # Build lists for all batch elements - As, bs, cs, b_idxs = [], [], [], [] - for i in range(batch_size): - A_aug = sp.csc_matrix( - (np.array(con_values[:, i]), *self.A_structure), - shape=self.A_shape, - ) - As.append(-A_aug[:, :-1]) # Negate A to match DIFFCP convention - bs.append(A_aug[:, -1].toarray().flatten()) - cs.append(np.array(lin_obj_values[:-1, i])) - b_idxs.append(self.b_idx) + + # Build matrices + As, bs, cs, b_idxs = _build_diffcp_matrices( + con_values, + lin_obj_values, + self.A_structure, + self.A_shape, + self.b_idx, + batch_size, + ) return DIFFCP_data( As=As, @@ -123,6 +192,55 @@ def jax_to_data(self, quad_obj_values, lin_obj_values, con_values) -> "DIFFCP_da ) +def _compute_gradients( + adj_batch: Callable, + dprimal: TensorLike, + ddual: TensorLike, + bs: list[np.ndarray], + b_idxs: list[np.ndarray], + batch_size: int, +) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Compute gradients using DIFFCP's adjoint method. + + Uses implicit differentiation to compute gradients of the optimization + solution with respect to problem parameters. The adjoint method efficiently + computes these gradients by solving the adjoint system. + + Args: + adj_batch: DIFFCP's batch adjoint function + dprimal: Incoming gradients w.r.t. primal solution + ddual: Incoming gradients w.r.t. dual solution + bs: List of RHS vectors from forward pass + b_idxs: List of RHS indices from forward pass + batch_size: Number of batch elements + + Returns: + Tuple of (dq_batch, dA_batch) where: + - dq_batch: List of gradients w.r.t. linear objective coefficients + - dA_batch: List of gradients w.r.t. constraint coefficients + """ + # Convert incoming gradients to lists for DIFFCP + dxs = [np.array(dprimal[i]) for i in range(batch_size)] + dys = [np.array(ddual[i]) for i in range(batch_size)] + dss = [np.zeros_like(bs[i]) for i in range(batch_size)] # No gradient w.r.t. slack + + # Call DIFFCP's batch adjoint to get gradients w.r.t. problem data + dAs, dbs, dcs = adj_batch(dxs, dys, dss) + + # Aggregate gradients from each batch element + dq_batch = [] + dA_batch = [] + for i in range(batch_size): + # Negate dA because A was negated in forward pass, but not db (b was not negated) + con_grad = np.hstack([-dAs[i].data, dbs[i][b_idxs[i]]]) + # Add zero gradient for constant offset term + lin_grad = np.hstack([dcs[i], np.array([0.0])]) + dA_batch.append(con_grad) + dq_batch.append(lin_grad) + + return dq_batch, dA_batch + + @dataclass class DIFFCP_data: As: list[sp.csc_matrix] @@ -161,23 +279,10 @@ def torch_derivative(self, primal, dual, adj_batch): "PyTorch interface requires 'torch' package. 
Install with: pip install torch" ) - # Split batched tensors into lists - dxs = [primal[i].numpy() for i in range(self.batch_size)] - dys = [dual[i].numpy() for i in range(self.batch_size)] - dss = [np.zeros_like(self.bs[i]) for i in range(self.batch_size)] - - # Call batch adjoint - dAs, dbs, dcs = adj_batch(dxs, dys, dss) - - # Aggregate gradients from each batch element - dq_batch = [] - dA_batch = [] - for i in range(self.batch_size): - # Negate dA because A was negated in forward pass, but not db (b was not negated) - con_grad = np.hstack([-dAs[i].data, dbs[i][self.b_idxs[i]]]) - lin_grad = np.hstack([dcs[i], np.array([0.0])]) - dA_batch.append(con_grad) - dq_batch.append(lin_grad) + # Compute gradients + dq_batch, dA_batch = _compute_gradients( + adj_batch, primal, dual, self.bs, self.b_idxs, self.batch_size + ) # Stack into shape (num_entries, batch_size) dq_stacked = torch.stack([torch.from_numpy(g) for g in dq_batch]).T @@ -215,23 +320,10 @@ def jax_solve(self, solver_args=None): return primal, dual, adj_batch def jax_derivative(self, dprimal, ddual, adj_batch): - # Split batched arrays into lists - dxs = [np.array(dprimal[i]) for i in range(self.batch_size)] - dys = [np.array(ddual[i]) for i in range(self.batch_size)] - dss = [np.zeros_like(self.bs[i]) for i in range(self.batch_size)] - - # Call batch adjoint - dAs, dbs, dcs = adj_batch(dxs, dys, dss) - - # Aggregate gradients from each batch element - dq_batch = [] - dA_batch = [] - for i in range(self.batch_size): - # Negate dA because A was negated in forward pass, but not db (b was not negated) - con_grad = np.hstack([-dAs[i].data, dbs[i][self.b_idxs[i]]]) - lin_grad = np.hstack([dcs[i], np.array([0.0])]) - dA_batch.append(con_grad) - dq_batch.append(lin_grad) + # Compute gradients + dq_batch, dA_batch = _compute_gradients( + adj_batch, dprimal, ddual, self.bs, self.b_idxs, self.batch_size + ) # Stack into shape (num_entries, batch_size) dq_stacked = jnp.stack([jnp.array(g) for g in dq_batch]).T diff --git a/src/cvxpylayers/interfaces/mpax_if.py b/src/cvxpylayers/interfaces/mpax_if.py index 69cd38a..c9cd2d9 100644 --- a/src/cvxpylayers/interfaces/mpax_if.py +++ b/src/cvxpylayers/interfaces/mpax_if.py @@ -1,5 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass +from typing import Any import numpy as np import scipy.sparse as sp @@ -20,6 +21,75 @@ torch = None # type: ignore[assignment] +def _parse_objective_structure( + objective_structure: tuple, +) -> tuple[slice, np.ndarray, tuple[np.ndarray, np.ndarray], tuple[int, int], int]: + """Parse objective structure to extract quadratic (Q) matrix components. + + Converts CVXPY's canonical objective structure into sparse matrix components + for the quadratic cost matrix Q in the QP formulation. 
+ + Args: + objective_structure: Tuple of (indices, indptr, (n, n)) from CVXPY + + Returns: + Tuple of (c_slice, Q_idxs, Q_structure, Q_shape, n) where: + - c_slice: Slice for linear cost vector c + - Q_idxs: Data indices for Q matrix values + - Q_structure: (indices, indptr) for Q sparse structure + - Q_shape: Shape (n, n) of Q matrix + - n: Number of primal variables + """ + obj_indices, obj_ptr, (n, _) = objective_structure + c_slice = slice(0, n) + + # Convert to CSR format for efficient row access + obj_csr = sp.csc_array( + (np.arange(obj_indices.size), obj_indices, obj_ptr), + shape=(n, n), + ).tocsr() + + Q_idxs = obj_csr.data + Q_structure = obj_csr.indices, obj_csr.indptr + Q_shape = (n, n) + + return c_slice, Q_idxs, Q_structure, Q_shape, n + + +def _initialize_solver(options: dict[str, Any] | None) -> tuple[Callable, bool]: + """Initialize MPAX solver based on options. + + Args: + options: Solver options dictionary containing: + - warm_start: Whether to use warm starting (currently must be False) + - algorithm: "raPDHG" or "r2HPDHG" + - Additional solver-specific options + + Returns: + Tuple of (jitted_solver_fn, warm_start_flag) + + Raises: + ValueError: If algorithm is not "raPDHG" or "r2HPDHG" + """ + if options is None: + options = {} + + warm_start = options.pop("warm_start", False) + assert warm_start is False + + algorithm = options.pop("algorithm", "raPDHG") + + if algorithm == "raPDHG": + alg = mpax.raPDHG + elif algorithm == "r2HPDHG": + alg = mpax.r2HPDHG + else: + raise ValueError("Invalid MPAX algorithm") + + solver = alg(warm_start=warm_start, **options) + return jax.jit(solver.optimize), warm_start + + class MPAX_ctx: Q_idxs: jnp.ndarray c_slice: slice @@ -55,20 +125,17 @@ def __init__( "MPAX solver requires 'mpax' and 'jax' packages to be installed. 
" "Install with: pip install mpax jax" ) - obj_indices, obj_ptr, (n, _) = objective_structure - self.c_slice = slice(0, n) - obj_csr = sp.csc_array( - (np.arange(obj_indices.size), obj_indices, obj_ptr), - shape=(n, n), - ).tocsr() - self.Q_idxs = obj_csr.data - self.Q_structure = obj_csr.indices, obj_csr.indptr - self.Q_shape = (n, n) + # Parse objective structure + self.c_slice, self.Q_idxs, self.Q_structure, self.Q_shape, n = _parse_objective_structure( + objective_structure + ) + + # Parse constraint structure - splits into equality (A) and inequality (G) matrices con_indices, con_ptr, (m, np1) = constraint_structure assert np1 == n + 1 - # Extract indices for the last column (which contains b and h values) + # Extract indices for the last column (which contains b and h RHS values) # Use indices instead of slices because sparse matrices may have reduced out # explicit zeros, so we need to reconstruct the full dense vectors self.last_col_start = con_ptr[-2] @@ -76,40 +143,32 @@ def __init__( self.last_col_indices = con_indices[self.last_col_start : self.last_col_end] self.m = m # Total number of constraint rows + # Convert to CSR format for row-based splitting con_csr = sp.csc_array( (np.arange(con_indices.size), con_indices, con_ptr[:-1]), shape=(m, n), ).tocsr() - split = con_csr.indptr[dims.zero] + split = con_csr.indptr[dims.zero] # Split point between equality and inequality + # Extract equality constraints (A) self.A_idxs = con_csr.data[:split] self.A_structure = con_csr.indices[:split], con_csr.indptr[: dims.zero + 1] self.A_shape = (dims.zero, n) + # Extract inequality constraints (G) self.G_idxs = con_csr.data[split:] self.G_structure = con_csr.indices[split:], con_csr.indptr[dims.zero :] - split self.G_shape = (m - dims.zero, n) - self.lower = lower_bounds if lower_bounds is not None else -jnp.inf * jnp.ones(n) - self.upper = upper_bounds if upper_bounds is not None else jnp.inf * jnp.ones(n) - # Precompute split_at to avoid binary search on every solve self.split_at = int(jnp.searchsorted(self.last_col_indices, dims.zero)) - if options is None: - options = {} - self.warm_start = options.pop("warm_start", False) - assert self.warm_start is False - algorithm = options.pop("algorithm", "raPDHG") + # Set bounds + self.lower = lower_bounds if lower_bounds is not None else -jnp.inf * jnp.ones(n) + self.upper = upper_bounds if upper_bounds is not None else jnp.inf * jnp.ones(n) - if algorithm == "raPDHG": - alg = mpax.raPDHG - elif algorithm == "r2HPDHG": - alg = mpax.r2HPDHG - else: - raise ValueError("Invalid MPAX algorithm") - solver = alg(warm_start=self.warm_start, **options) - self.solver = jax.jit(solver.optimize) + # Initialize solver + self.solver, self.warm_start = _initialize_solver(options) def jax_to_data( self, @@ -151,6 +210,107 @@ def torch_to_data( ) +def _extract_rhs_vectors( + con_vals_i: jnp.ndarray, ctx: "MPAX_ctx" +) -> tuple[jnp.ndarray, jnp.ndarray]: + """Extract and reconstruct b and h right-hand-side vectors from constraint values. + + CVXPY stores RHS values sparsely in the last column. This reconstructs + dense b (equality) and h (inequality) vectors from the sparse representation. 
+ + Args: + con_vals_i: Constraint coefficient values for single batch element + ctx: MPAX context with structure information + + Returns: + Tuple of (b_vals, h_vals) where: + - b_vals: Dense equality constraint RHS vector + - h_vals: Dense inequality constraint RHS vector + """ + # Extract sparse RHS values from last column + rhs_sparse_values = con_vals_i[ctx.last_col_start : ctx.last_col_end] + rhs_row_indices = ctx.last_col_indices + + num_eq_constraints = ctx.A_shape[0] + num_ineq_constraints = ctx.G_shape[0] + + # Split sparse values between equality (b) and inequality (h) constraints + split_at = ctx.split_at # Precomputed split index + + b_row_indices = rhs_row_indices[:split_at] + b_sparse_values = rhs_sparse_values[:split_at] + + h_row_indices = rhs_row_indices[split_at:] - num_eq_constraints + h_sparse_values = rhs_sparse_values[split_at:] + + # Reconstruct dense vectors from sparse representation + b_vals = jnp.zeros(num_eq_constraints) + h_vals = jnp.zeros(num_ineq_constraints) + + # Note: Negation matches MPAX's sign convention + b_vals = b_vals.at[b_row_indices].set(-b_sparse_values) + h_vals = h_vals.at[h_row_indices].set(-h_sparse_values) + + return b_vals, h_vals + + +def _build_and_solve_qp( + quad_obj_vals_i: jnp.ndarray, + lin_obj_vals_i: jnp.ndarray, + con_vals_i: jnp.ndarray, + ctx: "MPAX_ctx", + initial_primal: jnp.ndarray | None, + initial_dual: jnp.ndarray | None, +) -> tuple[jnp.ndarray, jnp.ndarray]: + """Build and solve a quadratic program for a single batch element. + + Constructs an MPAX QP model from parameter values and solves it using + the precompiled solver. + + Args: + quad_obj_vals_i: Quadratic objective coefficient values + lin_obj_vals_i: Linear objective coefficient values + con_vals_i: Constraint coefficient values + ctx: MPAX context with problem structure + initial_primal: Optional warm-start primal solution + initial_dual: Optional warm-start dual solution + + Returns: + Tuple of (primal_solution, dual_solution) + """ + # Extract RHS values and reconstruct b and h vectors + b_vals, h_vals = _extract_rhs_vectors(con_vals_i, ctx) + + # Build QP model: minimize (1/2)x^T Q x + c^T x subject to Ax = b, Gx <= h, l <= x <= u + model = mpax.create_qp( + jax.experimental.sparse.BCSR( + (quad_obj_vals_i[ctx.Q_idxs], *ctx.Q_structure), + shape=ctx.Q_shape, + ), + lin_obj_vals_i[ctx.c_slice], + jax.experimental.sparse.BCSR( + (con_vals_i[ctx.A_idxs], *ctx.A_structure), + shape=ctx.A_shape, + ), + b_vals, + jax.experimental.sparse.BCSR( + (con_vals_i[ctx.G_idxs], *ctx.G_structure), + shape=ctx.G_shape, + ), + h_vals, + ctx.lower, + ctx.upper, + ) + + # Solve with optional warm start + solution = ctx.solver( + model, + initial_primal_solution=initial_primal, + initial_dual_solution=initial_dual, + ) + return solution.primal_solution, solution.dual_solution + + @dataclass class MPAX_data: ctx: "MPAX_ctx" # Reference to context with structure info @@ -170,57 +330,14 @@ def jax_solve(self, solver_args=None): def solve_single_batch(quad_obj_vals_i, lin_obj_vals_i, con_vals_i): """Build model and solve for a single batch element.""" - # Extract RHS values and reconstruct b and h vectors - # (same logic as old jax_to_data, but for single batch element) - rhs_sparse_values = con_vals_i[self.ctx.last_col_start : self.ctx.last_col_end] - rhs_row_indices = self.ctx.last_col_indices - - num_eq_constraints = self.ctx.A_shape[0] - num_ineq_constraints = self.ctx.G_shape[0] - - # Use precomputed split_at from context - split_at = self.ctx.split_at - - b_row_indices = 
rhs_row_indices[:split_at] - b_sparse_values = rhs_sparse_values[:split_at] - - h_row_indices = rhs_row_indices[split_at:] - num_eq_constraints - h_sparse_values = rhs_sparse_values[split_at:] - - b_vals = jnp.zeros(num_eq_constraints) - h_vals = jnp.zeros(num_ineq_constraints) - - b_vals = b_vals.at[b_row_indices].set(-b_sparse_values) - h_vals = h_vals.at[h_row_indices].set(-h_sparse_values) - - # Build QP model - model = mpax.create_qp( - jax.experimental.sparse.BCSR( - (quad_obj_vals_i[self.ctx.Q_idxs], *self.ctx.Q_structure), - shape=self.ctx.Q_shape, - ), - lin_obj_vals_i[self.ctx.c_slice], - jax.experimental.sparse.BCSR( - (con_vals_i[self.ctx.A_idxs], *self.ctx.A_structure), - shape=self.ctx.A_shape, - ), - b_vals, - jax.experimental.sparse.BCSR( - (con_vals_i[self.ctx.G_idxs], *self.ctx.G_structure), - shape=self.ctx.G_shape, - ), - h_vals, - self.ctx.lower, - self.ctx.upper, - ) - - # Solve with optional warm start - solution = self.ctx.solver( - model, - initial_primal_solution=initial_primal, - initial_dual_solution=initial_dual, + return _build_and_solve_qp( + quad_obj_vals_i, + lin_obj_vals_i, + con_vals_i, + self.ctx, + initial_primal, + initial_dual, ) - return solution.primal_solution, solution.dual_solution # Vectorize over batch dimension (axis 1 of parameter arrays) solve_batched = jax.vmap(solve_single_batch, in_axes=(1, 1, 1)) diff --git a/src/cvxpylayers/jax/cvxpylayer.py b/src/cvxpylayers/jax/cvxpylayer.py index 5b8b945..57aced6 100644 --- a/src/cvxpylayers/jax/cvxpylayer.py +++ b/src/cvxpylayers/jax/cvxpylayer.py @@ -9,6 +9,122 @@ import cvxpylayers.utils.parse_args as pa +def _apply_gp_log_transform( + params: tuple[jnp.ndarray, ...], + ctx: pa.LayersContext, +) -> tuple[jnp.ndarray, ...]: + """Apply log transformation to geometric program (GP) parameters. + + Geometric programs are solved in log-space after conversion to DCP. + This function applies log transformation to the appropriate parameters. + + Args: + params: Tuple of parameter arrays in original GP space + ctx: Layer context containing GP parameter mapping info + + Returns: + Tuple of transformed parameters (log-space for GP params, unchanged otherwise) + """ + if not ctx.gp or not ctx.gp_param_to_log_param: + return params + + params_transformed = [] + for i, param in enumerate(params): + cvxpy_param = ctx.parameters[i] + if cvxpy_param in ctx.gp_param_to_log_param: + # This parameter needs log transformation for GP + params_transformed.append(jnp.log(param)) + else: + params_transformed.append(param) + return tuple(params_transformed) + + +def _flatten_and_batch_params( + params: tuple[jnp.ndarray, ...], + ctx: pa.LayersContext, + batch: tuple, +) -> jnp.ndarray: + """Flatten and batch parameters into a single stacked array. + + Converts a tuple of parameter arrays (potentially with mixed batched/unbatched) + into a single concatenated array suitable for matrix multiplication with the + parametrized problem matrices. 
+ + Args: + params: Tuple of parameter arrays + ctx: Layer context with batch info and ordering + batch: Batch dimensions tuple (empty if unbatched) + + Returns: + Concatenated parameter array with shape (num_params, batch_size) or (num_params,) + """ + flattened_params: list[jnp.ndarray | None] = [None] * (len(params) + 1) + + for i, param in enumerate(params): + # Check if this parameter is batched or needs broadcasting + if ctx.batch_sizes[i] == 0 and batch: + # Unbatched parameter - expand to match batch size + param_expanded = jnp.expand_dims(param, 0) + param_expanded = jnp.broadcast_to(param_expanded, batch + param.shape) + flattened_params[ctx.user_order_to_col_order[i]] = jnp.reshape( + param_expanded, + batch + (-1,), + order="F", + ) + else: + # Already batched or no batch dimension needed + flattened_params[ctx.user_order_to_col_order[i]] = jnp.reshape( + param, + batch + (-1,), + order="F", + ) + + # Add constant 1.0 column for offset terms in canonical form + flattened_params[-1] = jnp.ones(batch + (1,), dtype=params[0].dtype) + assert all(p is not None for p in flattened_params), "All parameters must be assigned" + + p_stack = jnp.concatenate(cast(list[jnp.ndarray], flattened_params), -1) + # When batched, p_stack is (batch_size, num_params) but we need (num_params, batch_size) + if batch: + p_stack = p_stack.T + return p_stack + + +def _recover_results( + primal: jnp.ndarray, + dual: jnp.ndarray, + ctx: pa.LayersContext, + batch: tuple, +) -> tuple[jnp.ndarray, ...]: + """Recover variable values from primal/dual solutions. + + Extracts the requested variables from the solver's primal and dual + solutions, applies inverse GP transformation if needed, and removes + batch dimension for unbatched inputs. + + Args: + primal: Primal solution from solver + dual: Dual solution from solver + ctx: Layer context with variable recovery info + batch: Batch dimensions tuple (empty if unbatched) + + Returns: + Tuple of recovered variable values + """ + # Extract each variable using its slice from the solution vectors + results = tuple(var.recover(primal, dual) for var in ctx.var_recover) + + # Apply exp transformation to recover from log-space for GP + if ctx.gp: + results = tuple(jnp.exp(r) for r in results) + + # Squeeze batch dimension for unbatched inputs + if not batch: + results = tuple(jnp.squeeze(r, 0) for r in results) + + return results + + class CvxpyLayer: def __init__( self, @@ -21,7 +137,6 @@ def __init__( canon_backend: str | None = None, solver_args: dict[str, Any] | None = None, ) -> None: - assert gp is False if solver_args is None: solver_args = {} self.ctx = pa.parse_args( @@ -47,33 +162,14 @@ def __call__( if solver_args is None: solver_args = {} batch = self.ctx.validate_params(list(params)) - flattened_params: list[jnp.ndarray | None] = [None] * (len(params) + 1) - for i, param in enumerate(params): - # Check if this parameter is batched or needs broadcasting - if self.ctx.batch_sizes[i] == 0 and batch: - # Unbatched parameter - expand to match batch size - # Add batch dimension by repeating - param_expanded = jnp.expand_dims(param, 0) - param_expanded = jnp.broadcast_to(param_expanded, batch + param.shape) - flattened_params[self.ctx.user_order_to_col_order[i]] = jnp.reshape( - param_expanded, - batch + (-1,), - order="F", - ) - else: - # Already batched or no batch dimension needed - flattened_params[self.ctx.user_order_to_col_order[i]] = jnp.reshape( - param, - batch + (-1,), - order="F", - ) - flattened_params[-1] = jnp.ones(batch + (1,), dtype=params[0].dtype) 
- # Assert all parameters have been assigned (no Nones remain) - assert all(p is not None for p in flattened_params), "All parameters must be assigned" - p_stack = jnp.concatenate(cast(list[jnp.ndarray], flattened_params), -1) - # When batched, p_stack is (batch_size, num_params) but we need (num_params, batch_size) - if batch: - p_stack = p_stack.T + + # Apply log transformation to GP parameters + params = _apply_gp_log_transform(params, self.ctx) + + # Flatten and batch parameters + p_stack = _flatten_and_batch_params(params, self.ctx, batch) + + # Evaluate parametrized matrices P_eval = self.P @ p_stack if self.P is not None else None q_eval = self.q @ p_stack A_eval = self.A @ p_stack @@ -108,13 +204,9 @@ def solve_problem_bwd(res, g): solve_problem.defvjp(solve_problem_fwd, solve_problem_bwd) primal, dual = solve_problem(P_eval, q_eval, A_eval) - results = tuple(var.recover(primal, dual) for var in self.ctx.var_recover) - - # Squeeze batch dimension for unbatched inputs - if not batch: - results = tuple(jnp.squeeze(r, 0) for r in results) - return results + # Recover results and apply GP inverse transform if needed + return _recover_results(primal, dual, self.ctx, batch) def scipy_csr_to_jax_bcsr( diff --git a/src/cvxpylayers/torch/cvxpylayer.py b/src/cvxpylayers/torch/cvxpylayer.py index b6aa2f4..bd4599a 100644 --- a/src/cvxpylayers/torch/cvxpylayer.py +++ b/src/cvxpylayers/torch/cvxpylayer.py @@ -7,6 +7,123 @@ import cvxpylayers.utils.parse_args as pa +def _apply_gp_log_transform( + params: tuple[torch.Tensor, ...], + ctx: pa.LayersContext, +) -> tuple[torch.Tensor, ...]: + """Apply log transformation to geometric program (GP) parameters. + + Geometric programs are solved in log-space after conversion to DCP. + This function applies log transformation to the appropriate parameters. + + Args: + params: Tuple of parameter tensors in original GP space + ctx: Layer context containing GP parameter mapping info + + Returns: + Tuple of transformed parameters (log-space for GP params, unchanged otherwise) + """ + if not ctx.gp or not ctx.gp_param_to_log_param: + return params + + params_transformed = [] + for i, param in enumerate(params): + cvxpy_param = ctx.parameters[i] + if cvxpy_param in ctx.gp_param_to_log_param: + # This parameter needs log transformation for GP + params_transformed.append(torch.log(param)) + else: + params_transformed.append(param) + return tuple(params_transformed) + + +def _flatten_and_batch_params( + params: tuple[torch.Tensor, ...], + ctx: pa.LayersContext, + batch: tuple, +) -> torch.Tensor: + """Flatten and batch parameters into a single stacked tensor. + + Converts a tuple of parameter tensors (potentially with mixed batched/unbatched) + into a single concatenated tensor suitable for matrix multiplication with the + parametrized problem matrices. 
+ + Args: + params: Tuple of parameter tensors + ctx: Layer context with batch info and ordering + batch: Batch dimensions tuple (empty if unbatched) + + Returns: + Concatenated parameter tensor with shape (num_params, batch_size) or (num_params,) + """ + flattened_params: list[torch.Tensor | None] = [None] * (len(params) + 1) + + for i, param in enumerate(params): + # Check if this parameter is batched or needs broadcasting + if ctx.batch_sizes[i] == 0 and batch: + # Unbatched parameter - expand to match batch size + param_expanded = param.unsqueeze(0).expand(batch + param.shape) + flattened_params[ctx.user_order_to_col_order[i]] = reshape_fortran( + param_expanded, + batch + (-1,), + ) + else: + # Already batched or no batch dimension needed + flattened_params[ctx.user_order_to_col_order[i]] = reshape_fortran( + param, + batch + (-1,), + ) + + # Add constant 1.0 column for offset terms in canonical form + flattened_params[-1] = torch.ones( + batch + (1,), + dtype=params[0].dtype, + device=params[0].device, + ) + assert all(p is not None for p in flattened_params), "All parameters must be assigned" + + p_stack = torch.cat(cast(list[torch.Tensor], flattened_params), -1) + # When batched, p_stack is (batch_size, num_params) but we need (num_params, batch_size) + if batch: + p_stack = p_stack.T + return p_stack + + +def _recover_results( + primal: torch.Tensor, + dual: torch.Tensor, + ctx: pa.LayersContext, + batch: tuple, +) -> tuple[torch.Tensor, ...]: + """Recover variable values from primal/dual solutions. + + Extracts the requested variables from the solver's primal and dual + solutions, applies inverse GP transformation if needed, and removes + batch dimension for unbatched inputs. + + Args: + primal: Primal solution from solver + dual: Dual solution from solver + ctx: Layer context with variable recovery info + batch: Batch dimensions tuple (empty if unbatched) + + Returns: + Tuple of recovered variable values + """ + # Extract each variable using its slice from the solution vectors + results = tuple(var.recover(primal, dual) for var in ctx.var_recover) + + # Apply exp transformation to recover from log-space for GP + if ctx.gp: + results = tuple(torch.exp(r) for r in results) + + # Squeeze batch dimension for unbatched inputs + if not batch: + results = tuple(r.squeeze(0) for r in results) + + return results + + class CvxpyLayer(torch.nn.Module): def __init__( self, @@ -20,7 +137,6 @@ def __init__( solver_args: dict[str, Any] | None = None, ) -> None: super().__init__() - assert gp is False if solver_args is None: solver_args = {} self.ctx = pa.parse_args( @@ -48,37 +164,19 @@ def forward( if solver_args is None: solver_args = {} batch = self.ctx.validate_params(list(params)) - flattened_params: list[torch.Tensor | None] = [None] * (len(params) + 1) - for i, param in enumerate(params): - # Check if this parameter is batched or needs broadcasting - if self.ctx.batch_sizes[i] == 0 and batch: - # Unbatched parameter - expand to match batch size - # Add batch dimension by repeating - param_expanded = param.unsqueeze(0).expand(batch + param.shape) - flattened_params[self.ctx.user_order_to_col_order[i]] = reshape_fortran( - param_expanded, - batch + (-1,), - ) - else: - # Already batched or no batch dimension needed - flattened_params[self.ctx.user_order_to_col_order[i]] = reshape_fortran( - param, - batch + (-1,), - ) - flattened_params[-1] = torch.ones( - batch + (1,), - dtype=params[0].dtype, - device=params[0].device, - ) - # Assert all parameters have been assigned (no Nones remain) - 
assert all(p is not None for p in flattened_params), "All parameters must be assigned" - p_stack = torch.cat(cast(list[torch.Tensor], flattened_params), -1) - # When batched, p_stack is (batch_size, num_params) but we need (num_params, batch_size) - if batch: - p_stack = p_stack.T + + # Apply log transformation to GP parameters + params = _apply_gp_log_transform(params, self.ctx) + + # Flatten and batch parameters + p_stack = _flatten_and_batch_params(params, self.ctx, batch) + + # Evaluate parametrized matrices P_eval = self.P @ p_stack if self.P is not None else None q_eval = self.q @ p_stack A_eval = self.A @ p_stack + + # Solve optimization problem primal, dual, _, _ = _CvxpyLayer.apply( # type: ignore[misc] P_eval, q_eval, @@ -86,13 +184,9 @@ def forward( self.ctx, solver_args, ) - results = tuple(var.recover(primal, dual) for var in self.ctx.var_recover) - - # Squeeze batch dimension for unbatched inputs - if not batch: - results = tuple(r.squeeze(0) for r in results) - return results + # Recover results and apply GP inverse transform if needed + return _recover_results(primal, dual, self.ctx, batch) class _CvxpyLayer(torch.autograd.Function): diff --git a/src/cvxpylayers/utils/dgp_reduction.py b/src/cvxpylayers/utils/dgp_reduction.py new file mode 100644 index 0000000..d10e49b --- /dev/null +++ b/src/cvxpylayers/utils/dgp_reduction.py @@ -0,0 +1,88 @@ +"""Custom DGP to DCP reduction that works without parameter values. + +This module provides a custom implementation of CVXPY's DGP→DCP reduction +that allows cvxpylayers to build computation graphs without requiring +parameter values to be set upfront. +""" + +from typing import Any + +import cvxpy as cp +import numpy as np +from cvxpy.reductions.dgp2dcp.canonicalizers import DgpCanonMethods + + +class _DgpCanonMethodsNoValueCheck(DgpCanonMethods): # type: ignore[misc] + """Custom DGP canonicalization methods that work without parameter values.""" + + def parameter_canon( + self, parameter: cp.Parameter, args: list[Any] + ) -> tuple[cp.Parameter, list[Any]]: + """Canonicalize a parameter without requiring it to have a value. + + Args: + parameter: The parameter to canonicalize + args: Arguments (unused) + + Returns: + Tuple of (log-space parameter, constraints) + """ + del args + # Swaps out positive parameters for unconstrained parameters. + if parameter in self._parameters: + return self._parameters[parameter], [] + else: + # Create log-space parameter, preserving None value if present + log_parameter = cp.Parameter( + parameter.shape, + name=parameter.name(), + value=np.log(parameter.value) if parameter.value is not None else None, + ) + self._parameters[parameter] = log_parameter + return log_parameter, [] + + +class _Dgp2DcpNoValueCheck(cp.reductions.Dgp2Dcp): # type: ignore[misc] + """DGP to DCP reduction that works without parameter values. + + This is an internal cvxpylayers class that bypasses CVXPY's requirement + for parameters to have values during the DGP→DCP transformation. + + CVXPY's Dgp2Dcp.accepts() checks that all parameters have values, but + this is unnecessary - the transformation is purely symbolic and doesn't + actually need the values until solve time. + + This class is NOT monkey patching - it's a separate class used only + within cvxpylayers. CVXPY's original Dgp2Dcp remains unchanged. + """ + + def accepts(self, problem: cp.Problem) -> bool: + """Accept DGP problems even without parameter values. 
+ + Args: + problem: The CVXPY problem to check + + Returns: + True if the problem is DGP, False otherwise + """ + return problem.is_dgp() + + def apply(self, problem: cp.Problem) -> tuple[cp.Problem, Any]: + """Apply DGP to DCP reduction using custom canon methods. + + Args: + problem: The DGP problem to reduce + + Returns: + Tuple of (DCP problem, inverse data) + """ + if not self.accepts(problem): + raise ValueError("The supplied problem is not DGP.") + + # Use our custom canon methods that handle None parameter values + self.canon_methods = _DgpCanonMethodsNoValueCheck() + equiv_problem, inverse_data = super(cp.reductions.Dgp2Dcp, self).apply( # type: ignore[misc] + problem + ) + inverse_data._problem = problem + return equiv_problem, inverse_data diff --git a/src/cvxpylayers/utils/parse_args.py b/src/cvxpylayers/utils/parse_args.py index 9787cd6..a2cb840 100644 --- a/src/cvxpylayers/utils/parse_args.py +++ b/src/cvxpylayers/utils/parse_args.py @@ -3,8 +3,10 @@ import cvxpy as cp import scipy.sparse +from cvxpy.reductions.dcp2cone.cone_matrix_stuffing import ParamConeProg import cvxpylayers.interfaces +from cvxpylayers.utils.dgp_reduction import _Dgp2DcpNoValueCheck if TYPE_CHECKING: import torch @@ -80,6 +82,11 @@ class LayersContext: batch_sizes: list[int] | None = ( None # Track which params are batched (0=unbatched, N=batch size) ) + # GP (Geometric Programming) support + gp: bool = False + # Maps original GP parameters to their log-space DCP parameters + # Used to determine which parameters need log transformation in forward pass + gp_param_to_log_param: dict[cp.Parameter, cp.Parameter] | None = None def validate_params(self, values: list) -> tuple: if len(values) != len(self.parameters): @@ -133,19 +140,32 @@ def validate_params(self, values: list) -> tuple: return () -def parse_args( +def _validate_problem( problem: cp.Problem, variables: list[cp.Variable], parameters: list[cp.Parameter], - solver: str | None, - gp: bool = False, - verbose: bool = False, - canon_backend: str | None = None, - solver_args: dict[str, Any] | None = None, -) -> LayersContext: - if not problem.is_dcp(dpp=True): # type: ignore[call-arg] - raise ValueError("Problem must be DPP.") + gp: bool, +) -> None: + """Validate that the problem is DPP-compliant and inputs are well-formed. 
+ + Args: + problem: CVXPY problem to validate + variables: List of CVXPY variables to track + parameters: List of CVXPY parameters + gp: Whether this is a geometric program (GP) + + Raises: + ValueError: If problem is not DPP-compliant or inputs are invalid + """ + # Check if problem follows disciplined parametrized programming (DPP) rules + if gp: + if not problem.is_dgp(dpp=True): # type: ignore[call-arg] + raise ValueError("Problem must be DPP for geometric programming.") + else: + if not problem.is_dcp(dpp=True): # type: ignore[call-arg] + raise ValueError("Problem must be DPP.") + # Validate parameters match problem definition if not set(problem.parameters()) == set(parameters): raise ValueError("The layer's parameters must exactly match problem.parameters") if not set(variables).issubset(set(problem.variables())): @@ -155,14 +175,102 @@ def parse_args( if not isinstance(variables, list) and not isinstance(variables, tuple): raise ValueError("The layer's variables must be provided as a list or tuple") + +def _build_user_order_mapping( + parameters: list[cp.Parameter], + param_prob: ParamConeProg, + gp: bool, + gp_param_to_log_param: dict[cp.Parameter, cp.Parameter] | None, +) -> dict[int, int]: + """Build mapping from user parameter order to column order. + + CVXPY internally reorders parameters when canonicalizing problems. This + creates a mapping from the user's parameter order to the internal column + order used in the canonical form. + + Args: + parameters: List of CVXPY parameters in user order + param_prob: CVXPY's parametrized problem object + gp: Whether this is a geometric program + gp_param_to_log_param: Mapping from GP params to log-space DCP params + + Returns: + Dictionary mapping user parameter index to column order index + """ + # For GP problems, we need to use the log-space DCP parameter IDs + if gp and gp_param_to_log_param: + # Map user order index to column using log-space DCP parameters + user_order_to_col = { + i: param_prob.param_id_to_col[ + gp_param_to_log_param[p].id if p in gp_param_to_log_param else p.id + ] + for i, p in enumerate(parameters) + } + else: + # Standard DCP problem - use original parameters + user_order_to_col = { + i: col + for col, i in sorted( + [(param_prob.param_id_to_col[p.id], i) for i, p in enumerate(parameters)], + ) + } + + # Convert column indices to sequential order mapping + user_order_to_col_order = {} + for j, i in enumerate(user_order_to_col.keys()): + user_order_to_col_order[i] = j + + return user_order_to_col_order + + +def parse_args( + problem: cp.Problem, + variables: list[cp.Variable], + parameters: list[cp.Parameter], + solver: str | None, + gp: bool = False, + verbose: bool = False, + canon_backend: str | None = None, + solver_args: dict[str, Any] | None = None, +) -> LayersContext: + # Validate problem is DPP (disciplined parametrized programming) + _validate_problem(problem, variables, parameters, gp) + if solver is None: solver = "DIFFCP" - data, _, _ = problem.get_problem_data( - solver=solver, gp=gp, verbose=verbose, canon_backend=canon_backend, solver_opts=solver_args - ) + + # Handle GP problems using our custom reduction + gp_param_to_log_param = None + if gp: + # Apply custom DGP→DCP reduction that doesn't require parameter values + dgp2dcp = _Dgp2DcpNoValueCheck() + dcp_problem, _ = dgp2dcp.apply(problem) + + # Extract parameter mapping from the reduction + gp_param_to_log_param = dgp2dcp.canon_methods._parameters + + # Get problem data from the already-transformed DCP problem + data, _, _ = 
dcp_problem.get_problem_data( + solver=solver, + gp=False, + verbose=verbose, + canon_backend=canon_backend, + solver_opts=solver_args, + ) + else: + # Standard DCP path + data, _, _ = problem.get_problem_data( + solver=solver, + gp=False, + verbose=verbose, + canon_backend=canon_backend, + solver_opts=solver_args, + ) + param_prob = data[cp.settings.PARAM_PROB] # type: ignore[attr-defined] cone_dims = data["dims"] + # Create solver context solver_ctx = cvxpylayers.interfaces.get_solver_ctx( solver, param_prob, @@ -170,15 +278,11 @@ def parse_args( data, solver_args, ) - user_order_to_col = { - i: col - for col, i in sorted( - [(param_prob.param_id_to_col[p.id], i) for i, p in enumerate(parameters)], - ) - } - user_order_to_col_order = {} - for j, i in enumerate(user_order_to_col.keys()): - user_order_to_col_order[i] = j + + # Build parameter ordering mapping + user_order_to_col_order = _build_user_order_mapping( + parameters, param_prob, gp, gp_param_to_log_param + ) q = getattr(param_prob, "q", getattr(param_prob, "c", None)) @@ -197,4 +301,6 @@ def parse_args( for v in variables ], user_order_to_col_order=user_order_to_col_order, + gp=gp, + gp_param_to_log_param=gp_param_to_log_param, ) diff --git a/tests/test_jax.py b/tests/test_jax.py index b2f0688..97cfe27 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -399,7 +399,6 @@ def test_equality(): check_grads(layer, [b_jax], order=1, modes=["rev"]) -@pytest.mark.skip(reason="gp=True (geometric programming) not supported in JAX") def test_basic_gp(): x = cp.Variable(pos=True) y = cp.Variable(pos=True) @@ -412,7 +411,7 @@ def test_basic_gp(): objective_fn = 1 / (x * y * z) constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] problem = cp.Problem(cp.Minimize(objective_fn), constraints) - problem.solve(cp.SCS, gp=True) + problem.solve(cp.CLARABEL, gp=True) layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) a_jax = jnp.array(2.0) @@ -437,3 +436,118 @@ def test_basic_gp(): order=1, modes=["rev"], ) + + +def test_batched_gp(): + """Test GP with batched parameters.""" + x = cp.Variable(pos=True) + y = cp.Variable(pos=True) + z = cp.Variable(pos=True) + + # Batched parameters (need initial values for GP) + a = cp.Parameter(pos=True, value=2.0) + b = cp.Parameter(pos=True, value=1.0) + c = cp.Parameter(value=0.5) + + # Objective and constraints + objective_fn = 1 / (x * y * z) + constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] + problem = cp.Problem(cp.Minimize(objective_fn), constraints) + + # Create layer + layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) + + # Batched parameters - test with batch size 4 + # For scalar parameters, batching means 1D arrays + batch_size = 4 + a_batch = jnp.array([2.0, 1.5, 2.5, 1.8]) + b_batch = jnp.array([1.0, 1.2, 0.8, 1.5]) + c_batch = jnp.array([0.5, 0.6, 0.4, 0.5]) + + # Forward pass + x_batch, y_batch, z_batch = layer(a_batch, b_batch, c_batch) + + # Check shapes - batched results are (batch_size, 1) for scalar variables + assert x_batch.shape == (batch_size, 1) + assert y_batch.shape == (batch_size, 1) + assert z_batch.shape == (batch_size, 1) + + # Verify each batch element by solving individually + for i in range(batch_size): + a.value = float(a_batch[i]) + b.value = float(b_batch[i]) + c.value = float(c_batch[i]) + problem.solve(cp.CLARABEL, gp=True) + + assert np.allclose(x.value, x_batch[i, 0], atol=1e-4, rtol=1e-4), ( + f"Mismatch in batch {i} for x" + ) + assert np.allclose(y.value, y_batch[i, 0], atol=1e-4, 
rtol=1e-4), ( + f"Mismatch in batch {i} for y" + ) + assert np.allclose(z.value, z_batch[i, 0], atol=1e-4, rtol=1e-4), ( + f"Mismatch in batch {i} for z" + ) + + # Test gradients on batched problem + check_grads( + lambda a, b, c: jnp.sum( + layer(a, b, c, solver_args={"acceleration_lookback": 0})[0], + ), + [a_batch, b_batch, c_batch], + order=1, + modes=["rev"], + ) + + +def test_gp_without_param_values(): + """Test that GP layers can be created without setting parameter values.""" + x = cp.Variable(pos=True) + y = cp.Variable(pos=True) + z = cp.Variable(pos=True) + + # Create parameters WITHOUT setting values (this is the key test!) + a = cp.Parameter(pos=True, name="a") + b = cp.Parameter(pos=True, name="b") + c = cp.Parameter(name="c") + + # Build GP problem + objective_fn = 1 / (x * y * z) + constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] + problem = cp.Problem(cp.Minimize(objective_fn), constraints) + + # This should work WITHOUT needing to set a.value, b.value, c.value + layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) + + # Now use the layer with actual parameter values + a_jax = jnp.array(2.0) + b_jax = jnp.array(1.0) + c_jax = jnp.array(0.5) + + # Forward pass + x_jax, y_jax, z_jax = layer(a_jax, b_jax, c_jax) + + # Verify solution against CVXPY direct solve + a.value = 2.0 + b.value = 1.0 + c.value = 0.5 + problem.solve(cp.CLARABEL, gp=True) + + assert np.isclose(x.value, x_jax, atol=1e-5) + assert np.isclose(y.value, y_jax, atol=1e-5) + assert np.isclose(z.value, z_jax, atol=1e-5) + + # Test gradients + check_grads( + lambda a, b, c: jnp.sum( + layer( + a, + b, + c, + solver_args={"acceleration_lookback": 0}, + )[0], + ), + [a_jax, b_jax, c_jax], + order=1, + modes=["rev"], + ) diff --git a/tests/test_torch.py b/tests/test_torch.py index dc0d762..ed7b5d8 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -418,7 +418,6 @@ def test_equality(): torch.autograd.gradcheck(layer, b_th) -@pytest.mark.skip def test_basic_gp(): _ = set_seed(0) x = cp.Variable(pos=True) @@ -432,7 +431,7 @@ def test_basic_gp(): objective_fn = 1 / (x * y * z) constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] problem = cp.Problem(cp.Minimize(objective_fn), constraints) - problem.solve(cp.SCS, gp=True) + problem.solve(cp.CLARABEL, gp=True) layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) a_th = torch.tensor([2.0]).requires_grad_() @@ -451,6 +450,112 @@ def f(a, b, c): torch.autograd.gradcheck(f, (a_th, b_th, c_th), atol=1e-4) +def test_batched_gp(): + """Test GP with batched parameters.""" + _ = set_seed(0) + x = cp.Variable(pos=True) + y = cp.Variable(pos=True) + z = cp.Variable(pos=True) + + # Batched parameters (need initial values for GP) + a = cp.Parameter(pos=True, value=2.0) + b = cp.Parameter(pos=True, value=1.0) + c = cp.Parameter(value=0.5) + + # Objective and constraints + objective_fn = 1 / (x * y * z) + constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] + problem = cp.Problem(cp.Minimize(objective_fn), constraints) + + # Create layer + layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) + + # Batched parameters - test with batch size 4 (double precision) + # For scalar parameters, batching means 1D tensors + batch_size = 4 + a_batch = torch.tensor([2.0, 1.5, 2.5, 1.8], dtype=torch.float64, requires_grad=True) + b_batch = torch.tensor([1.0, 1.2, 0.8, 1.5], dtype=torch.float64, requires_grad=True) + c_batch = torch.tensor([0.5, 0.6, 0.4, 0.5], dtype=torch.float64, 
requires_grad=True) + + # Forward pass + x_batch, y_batch, z_batch = layer(a_batch, b_batch, c_batch) + + # Check shapes - batched results are (batch_size, 1) for scalar variables + assert x_batch.shape == (batch_size, 1) + assert y_batch.shape == (batch_size, 1) + assert z_batch.shape == (batch_size, 1) + + # Verify each batch element by solving individually + for i in range(batch_size): + a.value = a_batch[i].item() + b.value = b_batch[i].item() + c.value = c_batch[i].item() + problem.solve(cp.CLARABEL, gp=True) + + assert torch.allclose(torch.tensor(x.value), x_batch[i, 0], atol=1e-4, rtol=1e-4), ( + f"Mismatch in batch {i} for x" + ) + assert torch.allclose(torch.tensor(y.value), y_batch[i, 0], atol=1e-4, rtol=1e-4), ( + f"Mismatch in batch {i} for y" + ) + assert torch.allclose(torch.tensor(z.value), z_batch[i, 0], atol=1e-4, rtol=1e-4), ( + f"Mismatch in batch {i} for z" + ) + + # Test gradients on batched problem + def f_batch(a, b, c): + res = layer(a, b, c, solver_args={"acceleration_lookback": 0}) + return res[0].sum() + + torch.autograd.gradcheck(f_batch, (a_batch, b_batch, c_batch), atol=1e-3, rtol=1e-3) + + +def test_gp_without_param_values(): + """Test that GP layers can be created without setting parameter values.""" + _ = set_seed(0) + x = cp.Variable(pos=True) + y = cp.Variable(pos=True) + z = cp.Variable(pos=True) + + # Create parameters WITHOUT setting values (this is the key test!) + a = cp.Parameter(pos=True, name="a") + b = cp.Parameter(pos=True, name="b") + c = cp.Parameter(name="c") + + # Build GP problem + objective_fn = 1 / (x * y * z) + constraints = [a * (x * y + x * z + y * z) <= b, x >= y**c] + problem = cp.Problem(cp.Minimize(objective_fn), constraints) + + # This should work WITHOUT needing to set a.value, b.value, c.value + layer = CvxpyLayer(problem, parameters=[a, b, c], variables=[x, y, z], gp=True) + + # Now use the layer with actual parameter values + a_th = torch.tensor([2.0], dtype=torch.float64, requires_grad=True) + b_th = torch.tensor([1.0], dtype=torch.float64, requires_grad=True) + c_th = torch.tensor([0.5], dtype=torch.float64, requires_grad=True) + + # Forward pass + x_th, y_th, z_th = layer(a_th, b_th, c_th) + + # Verify solution against CVXPY direct solve + a.value = 2.0 + b.value = 1.0 + c.value = 0.5 + problem.solve(cp.CLARABEL, gp=True) + + assert torch.allclose(torch.tensor(x.value), x_th, atol=1e-5) + assert torch.allclose(torch.tensor(y.value), y_th, atol=1e-5) + assert torch.allclose(torch.tensor(z.value), z_th, atol=1e-5) + + # Test gradients + def f(a, b, c): + res = layer(a, b, c, solver_args={"acceleration_lookback": 0}) + return res[0].sum() + + torch.autograd.gradcheck(f, (a_th, b_th, c_th), atol=1e-4) + + def test_no_grad_context(): n, m = 2, 3 x = cp.Variable(n)
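
A minimal end-to-end sketch of the GP path this diff enables, mirroring the API exercised by test_gp_without_param_values above. The tiny problem and the import path are illustrative assumptions (import path as in the released cvxpylayers package); the point is that positive parameters need no value at layer-construction time, are passed to the forward pass in the original space, mapped through log by _apply_gp_log_transform, and the recovered variables are exponentiated back by _recover_results:

    import cvxpy as cp
    import torch
    from cvxpylayers.torch import CvxpyLayer

    # Hypothetical one-parameter DGP problem: minimize 1/x subject to a*x <= 1,
    # with x and a positive. No a.value is needed to build the layer.
    x = cp.Variable(pos=True)
    a = cp.Parameter(pos=True, name="a")
    problem = cp.Problem(cp.Minimize(1 / x), [a * x <= 1.0])

    layer = CvxpyLayer(problem, parameters=[a], variables=[x], gp=True)

    # The parameter is supplied in the original (positive) space; the layer takes
    # the log internally, solves the DCP-converted problem, and returns exp of
    # the log-space solution, so x_star is about 0.5 here (the largest x with a*x <= 1).
    a_th = torch.tensor([2.0], dtype=torch.float64, requires_grad=True)
    (x_star,) = layer(a_th)
    x_star.sum().backward()  # gradients flow back to a_th through the GP transform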