Coverage for src/cvx/bson/io.py: 87%

31 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-17 08:01 +0000

1# Copyright 2023 Stanford University Convex Optimization Group 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14"""Encoding and decoding utilities for numpy arrays and dataframes in BSON format.""" 

15 

16import json 

17from io import BytesIO 

18from typing import Any 

19 

20import numpy as np 

21import pandas as pd 

22import polars as pl 

23import pyarrow as pa 

24 

25 

def encode(data: np.ndarray | pd.DataFrame | pl.DataFrame) -> bytes:
    """Encode a numpy array, a pandas DataFrame or a polars DataFrame into bytes.

    The wire format depends on the type of ``data``:

    * ``np.ndarray``   -> an Arrow IPC tensor message (``pa.ipc.write_tensor``)
    * ``pd.DataFrame`` -> Parquet bytes (``DataFrame.to_parquet``)
    * ``pl.DataFrame`` -> Arrow IPC file bytes (``DataFrame.write_ipc``)
    * anything else    -> the tag ``b"cvx"`` followed by UTF-8 encoded JSON

    Args:
        data: The object to encode. Values that are not one of the three
            array/frame types must be JSON-serializable.

    Returns:
        The encoded bytes, which ``decode`` in this module turns back into
        an object.

    Raises:
        TypeError: If ``data`` is not one of the array/frame types and is not
            JSON-serializable (propagated from ``json.dumps``).
    """
    if isinstance(data, np.ndarray):
        tensor = pa.Tensor.from_numpy(obj=data)
        sink = pa.BufferOutputStream()
        pa.ipc.write_tensor(tensor, sink)
        # to_pybytes() already returns ``bytes``; no extra bytes() wrapper needed.
        return sink.getvalue().to_pybytes()

    if isinstance(data, pd.DataFrame):
        # With no path argument, to_parquet returns the file contents as bytes.
        return data.to_parquet()

    if isinstance(data, pl.DataFrame):
        # With file=None, write_ipc returns an in-memory BytesIO; getvalue()
        # yields its whole contents regardless of the stream position.
        return data.write_ipc(file=None).getvalue()

    # Fallback for plain Python objects: tag the JSON payload with b"cvx" so
    # the decoder can tell it apart from the binary formats above.
    return b"cvx" + json.dumps(data).encode(encoding="utf-8")

59 

60 

61def decode(data: bytes) -> np.ndarray | pd.DataFrame | pl.DataFrame: 

62 """Decode the bytes back into numpy array or pandas DataFrame. 

63 

64 Args: 

65 data: bytes 

66 

67 Returns: 

68 The array or the frame 

69 """ 

70 # reader the first few bytes 

71 header = data[:3] 

72 

73 # ARR indicates a pl.DataFrame 

74 if header == b"ARR": 

75 return pl.read_ipc(data) 

76 

77 # PAR indicates a pd.DataFrame 

78 if header == b"PAR": 

79 return pd.read_parquet(BytesIO(data)) 

80 

81 if header == b"cvx": 

82 return json.loads(data[3:].decode()) 

83 

84 # if still here we try numpy 

85 return pa.ipc.read_tensor(data).to_numpy()