Coverage for src/cvxcla/types.py: 76%
93 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 05:48 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 05:48 +0000
1# Copyright 2023 Stanford University Convex Optimization Group
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Type definitions and classes for the Critical Line Algorithm.
16This module defines the core data structures used in the Critical Line Algorithm:
17- FrontierPoint: Represents a point on the efficient frontier.
18- TurningPoint: Represents a turning point on the efficient frontier.
19- Frontier: Represents the entire efficient frontier.
21It also defines type aliases for commonly used types.
22"""
24from __future__ import annotations
26from collections.abc import Iterator
27from dataclasses import dataclass, field
29import numpy as np
30import plotly.graph_objects as go
31from numpy.typing import NDArray
33from .optimize import minimize
@dataclass(frozen=True)
class FrontierPoint:
    """A point on the efficient frontier.

    This class represents a portfolio on the efficient frontier, defined by its weights.
    It provides methods to compute the expected return and variance of the portfolio.

    Attributes:
        weights: Vector of portfolio weights for each asset. Array-like input (e.g. a list)
            is converted to a NumPy array on construction.

    """

    weights: NDArray[np.float64]

    def __post_init__(self):
        """Normalise and validate the portfolio weights.

        Converts ``weights`` to a NumPy array (so list input works with the matrix
        operations used by ``mean``/``variance``) and checks that the weights sum to 1,
        which is required for a valid portfolio.

        Raises:
            AssertionError: If the sum of weights is not close to 1.

        """
        # The dataclass is frozen, so bypass its __setattr__ to store the array.
        # np.asarray hands back a plain ndarray input unchanged (no copy).
        object.__setattr__(self, "weights", np.asarray(self.weights))

        # Raise explicitly instead of using ``assert`` so the check still runs under
        # ``python -O``; the exception type is kept for callers that catch it.
        total = np.sum(self.weights)
        if not np.isclose(total, 1.0):
            raise AssertionError(f"Portfolio weights must sum to 1, got {total}")

    def mean(self, mean: NDArray[np.float64]) -> float:
        """Compute the expected return of the portfolio.

        Args:
            mean: Vector of expected returns for each asset.

        Returns:
            The expected return of the portfolio, i.e. ``mean.T @ weights``.

        """
        return float(np.asarray(mean).T @ self.weights)

    def variance(self, covariance: NDArray[np.float64]) -> float:
        """Compute the expected variance of the portfolio.

        Args:
            covariance: Covariance matrix of asset returns.

        Returns:
            The expected variance of the portfolio, i.e. ``weights.T @ covariance @ weights``.

        """
        return float(self.weights.T @ covariance @ self.weights)
@dataclass(frozen=True)
class TurningPoint(FrontierPoint):
    """Turning point.

    A turning point is a vector of weights, a lambda value, and a boolean vector
    indicating which assets are free. All assets that are not free are blocked.

    Attributes:
        free: Mask with one entry per asset; truthy entries mark free assets.
            A boolean array is expected, but lists and 0/1 integer masks are accepted.
        lamb: Lambda value of the turning point (defaults to ``np.inf``).

    """

    free: NDArray[np.bool_]
    lamb: float = np.inf

    def _free_mask(self) -> NDArray[np.bool_]:
        """Return ``free`` coerced to a boolean NumPy array."""
        return np.asarray(self.free, dtype=bool)

    @property
    def free_indices(self) -> np.ndarray:
        """Returns the indices of the free assets."""
        return np.where(self._free_mask())[0]

    @property
    def blocked_indices(self) -> np.ndarray:
        """Returns the indices of the blocked assets."""
        # ``~`` on an integer mask is a bitwise NOT (0 -> -1, 1 -> -2), which would mark
        # every asset as blocked, and on a list it raises; negate a boolean mask instead.
        return np.where(~self._free_mask())[0]
@dataclass(frozen=True)
class Frontier:
    """A frontier is a list of frontier points. Some of them might be turning points.

    Attributes:
        mean: Vector of expected returns for each asset.
        covariance: Covariance matrix of asset returns.
        frontier: The points on the frontier, in order.

    """

    mean: NDArray[np.float64]
    covariance: NDArray[np.float64]
    frontier: list[FrontierPoint] = field(default_factory=list)

    def interpolate(self, num: int = 100) -> Frontier:
        """Interpolate the frontier with additional points between existing points.

        This method creates a new Frontier object with additional points interpolated
        between the existing points. This is useful for creating a smoother representation
        of the efficient frontier for visualization or analysis.

        Args:
            num: Number of samples per segment between adjacent points, counting both end
                points (as in ``np.linspace(0, 1, num)``). The first original point is
                kept, and each segment then contributes ``num - 1`` points: ``num - 2``
                new interior points followed by the segment's right end point.

        Returns:
            A new Frontier (same ``mean`` and ``covariance``) whose points are plain
            ``FrontierPoint`` objects running from the first to the last original point.

        """
        weights = self.weights
        # Position 0 on a segment is its left end point, which is already present
        # (as the first point or as the previous segment's right end point).
        alphas = np.linspace(0, 1, num)[1:]

        points: list[FrontierPoint] = []
        if len(weights) > 0:
            # Keep the first original point; the segments below only contribute
            # points strictly to the right of their left end point.
            points.append(FrontierPoint(weights=weights[0]))

        for w_start, w_end in zip(weights[:-1], weights[1:]):
            points.extend(FrontierPoint(weights=alpha * w_end + (1 - alpha) * w_start) for alpha in alphas)

        return Frontier(frontier=points, mean=self.mean, covariance=self.covariance)

    def __iter__(self) -> Iterator[FrontierPoint]:
        """Iterate over all frontier points."""
        yield from self.frontier

    def __len__(self) -> int:
        """Give number of frontier points."""
        return len(self.frontier)

    @property
    def weights(self) -> np.ndarray:
        """Matrix of weights. One row per point."""
        return np.array([point.weights for point in self])

    @property
    def returns(self) -> np.ndarray:
        """Vector of expected returns, one entry per point."""
        return np.array([point.mean(self.mean) for point in self])

    @property
    def variance(self) -> np.ndarray:
        """Vector of expected variances, one entry per point."""
        return np.array([point.variance(self.covariance) for point in self])

    @property
    def sharpe_ratio(self) -> np.ndarray:
        """Vector of Sharpe ratios: expected return over volatility (no risk-free rate subtracted)."""
        return self.returns / self.volatility

    @property
    def volatility(self) -> np.ndarray:
        """Vector of expected volatilities (square roots of the variances)."""
        return np.sqrt(self.variance)

    @property
    def max_sharpe(self) -> tuple[float, np.ndarray]:
        """Maximal Sharpe ratio on the frontier.

        The stored point with the largest Sharpe ratio is located first. The search is
        then refined along the straight segments to its immediate neighbours (where they
        exist) by minimising the negative Sharpe ratio of the convex combination with the
        ``minimize`` helper over the mixing weight in ``[0, 1]``.

        Returns:
            Tuple of maximal Sharpe ratio and the weights to achieve it

        Raises:
            ValueError: If the frontier is empty (raised by ``np.argmax``).

        """
        # Compute once: both properties rebuild their arrays on every access.
        weights = self.weights
        sharpe_ratios = self.sharpe_ratio

        def neg_sharpe(alpha: float, w_left: np.ndarray, w_right: np.ndarray) -> float:
            # convex combination of left and right weights
            weight = alpha * w_left + (1 - alpha) * w_right
            var = weight.T @ self.covariance @ weight
            returns = self.mean.T @ weight
            return -returns / np.sqrt(var)

        def refine(w_a: np.ndarray, w_b: np.ndarray) -> tuple[float, np.ndarray]:
            # Best mix alpha * w_a + (1 - alpha) * w_b for alpha in [0, 1].
            out = minimize(neg_sharpe, 0.5, args=(w_a, w_b), bounds=((0, 1),))
            alpha = out["x"][0]
            return -out["fun"], alpha * w_a + (1 - alpha) * w_b

        # in which point is the maximal Sharpe ratio?
        best = int(np.argmax(sharpe_ratios))
        right = min(best + 1, len(weights) - 1)
        left = max(best - 1, 0)

        # Look to the right of the best point (only if it has a right neighbour).
        if right > best:
            sharpe_ratio_right, w_right = refine(weights[best], weights[right])
        else:
            sharpe_ratio_right, w_right = sharpe_ratios[best], weights[best]

        # Look to the left of the best point (only if it has a left neighbour).
        if left < best:
            sharpe_ratio_left, w_left = refine(weights[left], weights[best])
        else:
            sharpe_ratio_left, w_left = sharpe_ratios[best], weights[best]

        if sharpe_ratio_left > sharpe_ratio_right:
            return sharpe_ratio_left, w_left
        return sharpe_ratio_right, w_right

    def plot(self, volatility: bool = False, markers: bool = True) -> go.Figure:
        """Plot the efficient frontier.

        This function creates a line plot of the efficient frontier, with expected return
        on the y-axis and either variance or volatility on the x-axis.

        Args:
            volatility: If True, plot volatility (standard deviation) on the x-axis.
                If False, plot variance on the x-axis.
            markers: If True, show markers at each point on the frontier.

        Returns:
            A plotly Figure object that can be displayed or saved.

        """
        fig = go.Figure()

        x = self.volatility if volatility else self.variance
        axis_title = "Expected volatility" if volatility else "Expected variance"

        fig.add_trace(
            go.Scatter(x=x, y=self.returns, mode="lines+markers" if markers else "lines", name="Efficient Frontier")
        )

        fig.update_layout(
            xaxis_title=axis_title,
            yaxis_title="Expected Return",
        )

        return fig