Coverage for src/cvx/bson/io.py: 87%
31 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-17 08:01 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-17 08:01 +0000
1# Copyright 2023 Stanford University Convex Optimization Group
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Encoding and decoding utilities for numpy arrays and dataframes in BSON format."""
16import json
17from io import BytesIO
18from typing import Any
20import numpy as np
21import pandas as pd
22import polars as pl
23import pyarrow as pa
def encode(data: np.ndarray | pd.DataFrame | pl.DataFrame | Any) -> bytes:
    """Serialize a supported object into bytes.

    The payload format depends on the type of ``data``:

    * ``np.ndarray``: written as an Arrow tensor via ``pa.ipc.write_tensor``.
    * ``pd.DataFrame``: written with ``DataFrame.to_parquet``.
    * ``pl.DataFrame``: written with ``DataFrame.write_ipc``.
    * anything else: dumped with ``json.dumps``, encoded as UTF-8 and
      prefixed with the marker ``b"cvx"``.

    Args:
        data: A numpy array, a pandas or polars DataFrame, or any other
            object that ``json.dumps`` can serialize.

    Returns:
        The encoded payload as bytes.

    Raises:
        TypeError: Propagated from ``json.dumps`` when ``data`` is none of
            the supported containers and is not JSON-serializable.
    """
    if isinstance(data, np.ndarray):
        tensor = pa.Tensor.from_numpy(obj=data)
        sink = pa.BufferOutputStream()
        pa.ipc.write_tensor(tensor, sink)
        # to_pybytes() already returns a bytes object; no extra copy needed.
        return sink.getvalue().to_pybytes()

    if isinstance(data, pd.DataFrame):
        # Without a path argument, to_parquet returns the file content.
        return data.to_parquet()

    if isinstance(data, pl.DataFrame):
        # With file=None polars writes into an in-memory BytesIO; getvalue()
        # returns the full content regardless of the stream position.
        return data.write_ipc(file=None).getvalue()

    # Fallback: JSON text tagged with a 3-byte marker so that the payload
    # can be told apart from the binary formats above.
    return b"cvx" + json.dumps(data).encode(encoding="utf-8")
def decode(data: bytes) -> np.ndarray | pd.DataFrame | pl.DataFrame | Any:
    """Decode bytes back into the object they were created from.

    The first three bytes of ``data`` select the decoder:

    * ``b"ARR"``: the payload is read with ``pl.read_ipc`` and a polars
      DataFrame is returned.
    * ``b"PAR"``: the payload is read with ``pd.read_parquet`` and a pandas
      DataFrame is returned.
    * ``b"cvx"``: the remainder is UTF-8 JSON text and the parsed JSON value
      is returned.
    * anything else: the payload is read as an Arrow tensor and returned as
      a numpy array.

    Args:
        data: The encoded payload.

    Returns:
        The decoded array, frame, or JSON value.
    """
    # Read the first few bytes; they identify the payload format.
    header = data[:3]

    # ARR indicates a pl.DataFrame
    if header == b"ARR":
        return pl.read_ipc(data)

    # PAR indicates a pd.DataFrame
    if header == b"PAR":
        return pd.read_parquet(BytesIO(data))

    # cvx marks a JSON payload: strip the 3-byte marker and parse the rest.
    if header == b"cvx":
        return json.loads(data[3:].decode("utf-8"))

    # No known marker matched, so try to read a numpy tensor.
    return pa.ipc.read_tensor(data).to_numpy()