Coverage for src/cvxcla/types.py: 100%
101 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-03 01:26 +0000
1# Copyright 2023 Stanford University Convex Optimization Group
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Type definitions and classes for the Critical Line Algorithm.
16This module defines the core data structures used in the Critical Line Algorithm:
17- FrontierPoint: Represents a point on the efficient frontier.
18- TurningPoint: Represents a turning point on the efficient frontier.
19- Frontier: Represents the entire efficient frontier.
21It also defines type aliases for commonly used types.
22"""
24from __future__ import annotations
26from collections.abc import Iterator
27from dataclasses import dataclass, field
28from typing import TYPE_CHECKING
30if TYPE_CHECKING:
31 import plotly.graph_objects as go
33import numpy as np
34from numpy.typing import NDArray
36from .optimize import minimize
@dataclass(frozen=True)
class FrontierPoint:
    """A point on the efficient frontier.

    This class represents a portfolio on the efficient frontier, defined by its weights.
    It provides methods to compute the expected return and variance of the portfolio.

    Attributes:
        weights: Vector of portfolio weights for each asset.

    """

    weights: NDArray[np.float64]

    def __post_init__(self) -> None:
        """Validate that the weights sum to 1.

        This method is automatically called after initialization to ensure that
        the portfolio weights sum to 1, which is required for a valid portfolio.

        Raises:
            ValueError: If the sum of weights is not close to 1.

        """
        # np.isclose with its default tolerances, so tiny rounding errors are accepted.
        if not np.isclose(np.sum(self.weights), 1.0):
            msg = "Weights do not sum to 1"
            raise ValueError(msg)

    def mean(self, mean: NDArray[np.float64]) -> float:
        """Compute the expected return of the portfolio.

        Args:
            mean: Vector of expected returns for each asset.

        Returns:
            The expected return of the portfolio, i.e. ``mean.T @ weights``.

        """
        # np.asarray is a no-op for ndarrays and lets plain sequences work too.
        return float(np.asarray(mean).T @ np.asarray(self.weights))

    def variance(self, covariance: NDArray[np.float64]) -> float:
        """Compute the expected variance of the portfolio.

        Args:
            covariance: Covariance matrix of asset returns.

        Returns:
            The expected variance of the portfolio, i.e. ``weights.T @ covariance @ weights``.

        """
        # Weights given as a plain sequence have no ``.T``; np.asarray does not copy ndarrays.
        weights = np.asarray(self.weights)
        return float(weights.T @ covariance @ weights)
@dataclass(frozen=True)
class TurningPoint(FrontierPoint):
    """Turning point of the efficient frontier.

    On top of the portfolio weights inherited from FrontierPoint, a turning
    point stores a boolean mask ``free`` marking the free assets (every asset
    that is not free is blocked) and its lambda value ``lamb``.
    """

    # Boolean mask: True for free assets, False for blocked ones.
    free: NDArray[np.bool_]
    # Lambda value of the turning point; defaults to infinity.
    lamb: float = np.inf

    @property
    def free_indices(self) -> np.ndarray:
        """Positions of the assets whose ``free`` flag is set."""
        return np.nonzero(self.free)[0]

    @property
    def blocked_indices(self) -> np.ndarray:
        """Positions of the assets whose ``free`` flag is not set."""
        blocked_mask = ~self.free
        return np.nonzero(blocked_mask)[0]
115@dataclass(frozen=True)
116class Frontier:
117 """A frontier is a list of frontier points. Some of them might be turning points."""
119 mean: NDArray[np.float64]
120 covariance: NDArray[np.float64]
121 frontier: list[FrontierPoint] = field(default_factory=list)
123 def interpolate(self, num: int = 100) -> Frontier:
124 """Interpolate the frontier with additional points between existing points.
126 This method creates a new Frontier object with additional points interpolated
127 between the existing points. This is useful for creating a smoother representation
128 of the efficient frontier for visualization or analysis.
130 Args:
131 num: The number of points to use in the interpolation. The method will create
132 num-1 new points between each pair of adjacent existing points.
134 Returns:
135 A new Frontier object with the interpolated points.
137 """
139 def _interpolate() -> Iterator[FrontierPoint]:
140 for w_right, w_left in zip(self.weights[0:-1], self.weights[1:], strict=False):
141 for lamb in np.linspace(0, 1, num):
142 if lamb > 0:
143 yield FrontierPoint(weights=lamb * w_left + (1 - lamb) * w_right)
145 points = list(_interpolate())
146 return Frontier(frontier=points, mean=self.mean, covariance=self.covariance)
148 def __iter__(self) -> Iterator[FrontierPoint]:
149 """Iterate over all frontier points."""
150 yield from self.frontier
152 def __len__(self) -> int:
153 """Give number of frontier points."""
154 return len(self.frontier)
156 @property
157 def weights(self) -> np.ndarray:
158 """Matrix of weights. One row per point."""
159 return np.array([point.weights for point in self])
161 @property
162 def returns(self) -> np.ndarray:
163 """Vector of expected returns."""
164 return np.array([point.mean(self.mean) for point in self])
166 @property
167 def variance(self) -> np.ndarray:
168 """Vector of expected variances."""
169 return np.array([point.variance(self.covariance) for point in self])
171 @property
172 def sharpe_ratio(self) -> np.ndarray:
173 """Vector of expected Sharpe ratios."""
174 return self.returns / self.volatility
176 @property
177 def volatility(self) -> np.ndarray:
178 """Vector of expected volatilities."""
179 return np.sqrt(self.variance)
181 @property
182 def max_sharpe(self) -> tuple[float, np.ndarray]:
183 """Maximal Sharpe ratio on the frontier.
185 Returns:
186 Tuple of maximal Sharpe ratio and the weights to achieve it
188 """
190 def neg_sharpe(alpha: float, *args: np.ndarray) -> float:
191 w_left, w_right = args[0], args[1]
192 # convex combination of left and right weights
193 weight = alpha * w_left + (1 - alpha) * w_right
194 # compute the variance
195 var = float(weight.T @ self.covariance @ weight)
196 returns = float(self.mean.T @ weight)
197 return float(-returns / np.sqrt(var))
199 sharpe_ratios = self.sharpe_ratio
201 # in which point is the maximal Sharpe ratio?
202 sr_position_max = np.argmax(self.sharpe_ratio)
204 # min only there for security...
205 right = min(sr_position_max + 1, len(self) - 1)
206 left = max(0, sr_position_max - 1)
208 # Look to the left and look to the right
210 if right > sr_position_max:
211 out = minimize(
212 neg_sharpe,
213 0.5,
214 args=(self.weights[sr_position_max], self.weights[right]),
215 bounds=((0, 1),),
216 )
217 var = out["x"][0]
218 w_right = var * self.weights[sr_position_max] + (1 - var) * self.weights[right]
219 sharpe_ratio_right = -out["fun"]
220 else:
221 w_right = self.weights[sr_position_max]
222 sharpe_ratio_right = sharpe_ratios[sr_position_max]
224 if left < sr_position_max:
225 out = minimize(
226 neg_sharpe,
227 0.5,
228 args=(self.weights[left], self.weights[sr_position_max]),
229 bounds=((0, 1),),
230 )
231 var = out["x"][0]
232 w_left = var * self.weights[left] + (1 - var) * self.weights[sr_position_max]
233 sharpe_ratio_left = -out["fun"]
234 else:
235 w_left = self.weights[sr_position_max]
236 sharpe_ratio_left = sharpe_ratios[sr_position_max]
238 if sharpe_ratio_left > sharpe_ratio_right:
239 return sharpe_ratio_left, w_left
241 return sharpe_ratio_right, w_right
243 def plot(self, volatility: bool = False, markers: bool = True) -> go.Figure:
244 """Plot the efficient frontier.
246 This function creates a line plot of the efficient frontier, with expected return
247 on the y-axis and either variance or volatility on the x-axis.
249 Args:
250 volatility: If True, plot volatility (standard deviation) on the x-axis.
251 If False, plot variance on the x-axis.
252 markers: If True, show markers at each point on the frontier.
254 Returns:
255 A plotly Figure object that can be displayed or saved.
257 """
258 try:
259 import plotly.graph_objects as go
260 except ImportError as e:
261 msg = "Plotting requires plotly. Install it with: pip install cvxcla[plot]"
262 raise ImportError(msg) from e
264 fig = go.Figure()
266 x = self.volatility if volatility else self.variance
267 axis_title = "Expected volatility" if volatility else "Expected variance"
269 fig.add_trace(
270 go.Scatter(x=x, y=self.returns, mode="lines+markers" if markers else "lines", name="Efficient Frontier")
271 )
273 fig.update_layout(
274 xaxis_title=axis_title,
275 yaxis_title="Expected Return",
276 )
278 return fig