Coverage for src / cvx / simulator / utils / interpolation.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-10 05:38 +0000

1# Copyright 2023 Stanford University Convex Optimization Group 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14"""Interpolation utilities for time series data. 

15 

16This module provides functions for interpolating missing values in time series 

17and validating that time series don't have missing values in the middle. 

18""" 

19 

20from typing import cast 

21 

22import pandas as pd 

23import polars as pl 

24 

25 

def interpolate(ts: pd.Series | pl.Series) -> pd.Series | pl.Series:
    """Forward-fill gaps in a time series between its first and last valid observations.

    Missing values that lie strictly between the first and the last non-missing
    observation are replaced by the most recent preceding observation (ffill).
    Missing values before the first or after the last observation are left
    untouched.

    The bounds are determined by position rather than by index label, so a
    pandas Series with repeated or unsorted index labels is handled the same
    way as one with a unique, sorted index.

    Parameters
    ----------
    ts : pd.Series or pl.Series
        The time series to interpolate

    Returns:
    -------
    pd.Series or pl.Series
        The interpolated time series. A polars input is delegated to
        ``interpolate_pl``. A pandas input without any valid observation is
        returned as is; otherwise a filled copy is returned and the input
        is not modified.

    Raises:
    ------
    TypeError
        If ``ts`` is neither a pandas nor a polars Series

    Examples:
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> interpolate(pd.Series([1, np.nan, np.nan, 4, 5])).tolist()
    [1.0, 1.0, 1.0, 4.0, 5.0]
    >>> interpolate(pd.Series([np.nan, 1, np.nan, 3, np.nan])).tolist()
    [nan, 1.0, 1.0, 3.0, nan]

    """
    # Check if the input is a valid type
    if not isinstance(ts, pd.Series | pl.Series):
        raise TypeError(f"Expected pd.Series or pl.Series, got {type(ts)}")  # noqa: TRY003

    # If the input is a polars Series, use the polars-specific function
    if isinstance(ts, pl.Series):
        return interpolate_pl(ts)

    # Integer positions of all valid observations
    valid_positions = ts.notna().to_numpy().nonzero()[0]
    if valid_positions.size == 0:
        # Empty or all-missing series: nothing to fill
        return ts

    start = int(valid_positions[0])
    stop = int(valid_positions[-1]) + 1

    # Work positionally (.iloc) instead of by label (.loc): a label slice would
    # also pick up rows outside the valid range that merely share a label with
    # the first or last valid row, and can raise on non-monotonic duplicates.
    result = ts.copy()
    result.iloc[start:stop] = ts.iloc[start:stop].ffill().array
    return result

73 

74 

def valid(ts: pd.Series | pl.Series) -> bool:
    """Tell whether a time series is free of gaps between its first and last valid values.

    Missing values (NaN/null) are tolerated at the start and at the end of the
    series, but not in between two valid observations.

    Parameters
    ----------
    ts : pd.Series or pl.Series
        The time series to check

    Returns:
    -------
    bool
        True if no value is missing between the first and last valid indices,
        False otherwise

    Raises:
    ------
    TypeError
        If ``ts`` is neither a pandas nor a polars Series

    Examples:
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> valid(pd.Series([np.nan, 1, 2, 3, np.nan]))
    True
    >>> valid(pd.Series([1, 2, np.nan, 4, 5]))
    False

    """
    # Polars input is handled by the dedicated polars implementation
    if isinstance(ts, pl.Series):
        return valid_pl(ts)

    # Anything that is not a pandas Series at this point is unsupported
    if not isinstance(ts, pd.Series):
        raise TypeError(f"Expected pd.Series or pl.Series, got {type(ts)}")  # noqa: TRY003

    # Interpolation only fills interior gaps. If dropping the missing values
    # leaves the same index before and after filling, there was no interior gap.
    filled = cast(pd.Series, interpolate(ts))
    observed_labels = ts.dropna().index
    filled_labels = filled.dropna().index
    return bool(observed_labels.equals(filled_labels))

115 

116 

def interpolate_pl(ts: pl.Series) -> pl.Series:
    """Forward-fill nulls in a polars series between its first and last valid values.

    Nulls that lie between the first and the last non-null value are replaced
    by the most recent preceding non-null value. Nulls before the first or
    after the last non-null value remain null. The name and dtype of the
    input series are carried over to the result.

    Parameters
    ----------
    ts : pl.Series
        The time series to interpolate

    Returns:
    -------
    pl.Series
        The interpolated time series. An empty or all-null input is returned
        as is.

    Examples:
    --------
    >>> import polars as pl
    >>> interpolate_pl(pl.Series([1, None, None, 4, 5])).to_list()
    [1, 1, 1, 4, 5]
    >>> filled = interpolate_pl(pl.Series("px", [None, 1.0, None, 3.0, None]))
    >>> filled.to_list()
    [None, 1.0, 1.0, 3.0, None]
    >>> filled.name
    'px'

    """
    # Nothing can be carried forward in an empty or entirely null series
    if ts.null_count() == len(ts):
        return ts

    # A plain forward fill leaves leading nulls alone (there is nothing to
    # carry yet) but also fills the trailing nulls, which must stay null.
    carried_forward = ts.fill_null(strategy="forward")

    # After a backward fill the only remaining nulls are those behind the last
    # valid value, so this mask is False exactly on the trailing nulls.
    not_trailing = ts.fill_null(strategy="backward").is_not_null()

    # Take the forward-filled value everywhere except on the trailing nulls,
    # where the original (null) value is kept.
    return carried_forward.zip_with(not_trailing, ts)

171 

172 

def valid_pl(ts: pl.Series) -> bool:
    """Tell whether a polars series is free of nulls between its first and last valid values.

    Nulls are tolerated at the start and at the end of the series, but not in
    between two non-null values. A series with at most one non-null value is
    always considered valid.

    Parameters
    ----------
    ts : pl.Series
        The time series to check

    Returns:
    -------
    bool
        True if no value is null between the first and last valid indices,
        False otherwise

    Examples:
    --------
    >>> import polars as pl
    >>> valid_pl(pl.Series([None, 1, 2, 3, None]))
    True
    >>> valid_pl(pl.Series([1, 2, None, 4, 5]))
    False

    """
    # Positions of all non-null values, in ascending order
    non_null_indices = ts.is_not_null().arg_true()
    observed = len(non_null_indices)

    # Zero or one valid value cannot enclose a gap
    if observed <= 1:
        return True

    # Every non-null position lies between the first and the last one, so the
    # series is gap-free exactly when their count equals the size of that span.
    span = non_null_indices[-1] - non_null_indices[0] + 1
    return bool(observed == span)

215 

216 

def interpolate_df_pl(df: pl.DataFrame) -> pl.DataFrame:
    """Forward-fill interior nulls column by column in a polars DataFrame.

    Each column is passed through ``interpolate_pl`` independently, so the
    first and last valid positions are determined per column. Nulls outside
    a column's valid range remain null.

    Parameters
    ----------
    df : pl.DataFrame
        The DataFrame to interpolate

    Returns:
    -------
    pl.DataFrame
        A new DataFrame with the same column names holding the interpolated columns

    Examples:
    --------
    >>> import polars as pl
    >>> df = pl.DataFrame({
    ...     'A': [1.0, None, None, 4.0, 5.0],
    ...     'B': [None, 2.0, None, 4.0, None]
    ... })
    >>> interpolate_df_pl(df).to_dict(as_series=False)
    {'A': [1.0, 1.0, 1.0, 4.0, 5.0], 'B': [None, 2.0, 2.0, 4.0, None]}

    """
    # Key the filled columns by their original names to rebuild the frame
    filled_columns = {name: interpolate_pl(df[name]) for name in df.columns}
    return pl.DataFrame(filled_columns)

262 

263 

def valid_df_pl(df: pl.DataFrame) -> bool:
    """Tell whether every column of a polars DataFrame is free of interior nulls.

    Each column is checked with ``valid_pl``: nulls at the start or end of a
    column are acceptable, nulls between two valid values are not.

    Parameters
    ----------
    df : pl.DataFrame
        The DataFrame to check

    Returns:
    -------
    bool
        True if all columns have no nulls between their first and last valid indices,
        False otherwise

    Examples:
    --------
    >>> import polars as pl
    >>> df1 = pl.DataFrame({
    ...     'A': [None, 1, 2, 3, None],
    ...     'B': [None, 2, 3, 4, None]
    ... })
    >>> valid_df_pl(df1)
    True
    >>> df2 = pl.DataFrame({
    ...     'A': [1, 2, None, 4, 5],
    ...     'B': [1, 2, 3, 4, 5]
    ... })
    >>> valid_df_pl(df2)
    False

    """
    # Stop at the first column that contains an interior null
    for name in df.columns:
        if not valid_pl(df[name]):
            return False
    return True