65 changes: 23 additions & 42 deletions bigframes/core/blocks.py
@@ -818,49 +818,30 @@ def _materialize_local(
total_rows = result_batches.approx_total_rows
# Remove downsampling config from subsequent invocations, as it could otherwise result in many
# iterations if downsampling undershoots
return self._downsample(
total_rows=total_rows,
sampling_method=sample_config.sampling_method,
fraction=fraction,
random_state=sample_config.random_state,
)._materialize_local(
MaterializationOptions(ordered=materialize_options.ordered)
)
else:
df = result_batches.to_pandas()
df = self._copy_index_to_pandas(df)
df.set_axis(self.column_labels, axis=1, copy=False)
return df, execute_result.query_job

def _downsample(
self, total_rows: int, sampling_method: str, fraction: float, random_state
) -> Block:
# either selecting fraction or number of rows
if sampling_method == _HEAD:
filtered_block = self.slice(stop=int(total_rows * fraction))
return filtered_block
elif (sampling_method == _UNIFORM) and (random_state is None):
filtered_expr = self.expr._uniform_sampling(fraction)
block = Block(
filtered_expr,
index_columns=self.index_columns,
column_labels=self.column_labels,
index_labels=self.index.names,
)
return block
elif sampling_method == _UNIFORM:
block = self.split(
fracs=(fraction,),
random_state=random_state,
sort=False,
)[0]
return block
if sample_config.sampling_method == "head":
# Just truncates the result iterator without a follow-up query
raw_df = result_batches.to_pandas(limit=int(total_rows * fraction))
elif (
sample_config.sampling_method == "uniform"
and sample_config.random_state is None
):
# Pushes sample into result without new query
sampled_batches = execute_result.batches(sample_rate=fraction)
raw_df = sampled_batches.to_pandas()
else: # uniform sample with random state requires a full follow-up query
down_sampled_block = self.split(
fracs=(fraction,),
random_state=sample_config.random_state,
sort=False,
)[0]
return down_sampled_block._materialize_local(
MaterializationOptions(ordered=materialize_options.ordered)
)
else:
# This part should never be called, just in case.
raise NotImplementedError(
f"The downsampling method {sampling_method} is not implemented, "
f"please choose from {','.join(_SAMPLING_METHODS)}."
)
raw_df = result_batches.to_pandas()
df = self._copy_index_to_pandas(raw_df)
df.set_axis(self.column_labels, axis=1, copy=False)
return df, execute_result.query_job

def split(
self,
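For orientation, here is a compact sketch (not part of this diff) of the three downsampling paths the rewritten _materialize_local chooses between; the function and its return strings are purely illustrative:

def describe_downsample_path(sampling_method: str, random_state) -> str:
    # Mirrors the branch structure added above.
    if sampling_method == "head":
        # Truncate the already-running result iterator; no follow-up query.
        return "to_pandas(limit=int(total_rows * fraction))"
    elif sampling_method == "uniform" and random_state is None:
        # Push the sample into the read itself; still no follow-up query.
        return "batches(sample_rate=fraction)"
    elif sampling_method == "uniform":
        # A seeded sample needs a deterministic split, which costs a full follow-up query.
        return "split(fracs=(fraction,), random_state=...) then _materialize_local(...)"
    raise NotImplementedError(f"sampling method {sampling_method!r} is not implemented")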
13 changes: 12 additions & 1 deletion bigframes/core/bq_data.py
@@ -186,11 +186,22 @@ def get_arrow_batches(
columns: Sequence[str],
storage_read_client: bigquery_storage_v1.BigQueryReadClient,
project_id: str,
sample_rate: Optional[float] = None,
) -> ReadResult:
table_mod_options = {}
read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}

predicates = []
if data.sql_predicate:
read_options_dict["row_restriction"] = data.sql_predicate
predicates.append(data.sql_predicate)
if sample_rate is not None:
assert isinstance(sample_rate, float)
predicates.append(f"RAND() < {sample_rate}")

if predicates:
full_predicates = " AND ".join(f"( {pred} )" for pred in predicates)
read_options_dict["row_restriction"] = full_predicates

read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)

if data.at_time:
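The sampling pushdown above works by AND-ing a RAND() predicate into the Storage Read API row restriction. A self-contained sketch of that composition (the helper name, columns, and predicate are placeholders, not bigframes API):

from google.cloud.bigquery_storage_v1 import types as bq_storage_types

def build_read_options(columns, sql_predicate=None, sample_rate=None):
    read_options_dict = {"selected_fields": list(columns)}
    predicates = []
    if sql_predicate:
        predicates.append(sql_predicate)
    if sample_rate is not None:
        # RAND() is evaluated server-side per row, so roughly sample_rate of rows survive.
        predicates.append(f"RAND() < {sample_rate}")
    if predicates:
        read_options_dict["row_restriction"] = " AND ".join(f"( {pred} )" for pred in predicates)
    return bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)

# e.g. build_read_options(["name", "age"], sql_predicate="age > 21", sample_rate=0.1)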
11 changes: 10 additions & 1 deletion bigframes/core/local_data.py
@@ -25,6 +25,7 @@
import uuid

import geopandas # type: ignore
import numpy
import numpy as np
import pandas as pd
import pyarrow as pa
@@ -124,13 +125,21 @@ def to_arrow(
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
sample_rate: Optional[float] = None,
max_chunksize: Optional[int] = None,
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
if geo_format != "wkt":
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
assert json_type == "string"

batches = self.data.to_batches(max_chunksize=max_chunksize)
data = self.data

# This exists for symmetry with remote sources, but sampling local data like this shouldn't really happen
if sample_rate is not None:
to_take = numpy.random.rand(data.num_rows) < sample_rate
data = data.filter(to_take)

batches = data.to_batches(max_chunksize=max_chunksize)
schema = self.data.schema
if duration_type == "int":
schema = _schema_durations_to_ints(schema)
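The local fallback samples by filtering the pyarrow table with a random boolean mask, as the hunk above does; a standalone sketch with made-up data:

import numpy as np
import pyarrow as pa

table = pa.table({"x": list(range(100))})
sample_rate = 0.25
# Each row is kept independently with probability ~sample_rate.
mask = np.random.rand(table.num_rows) < sample_rate
sampled = table.filter(mask)
print(sampled.num_rows)  # roughly 25; varies run to run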
26 changes: 15 additions & 11 deletions bigframes/session/executor.py
@@ -88,7 +88,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:

yield batch

def to_arrow_table(self) -> pyarrow.Table:
def to_arrow_table(self, limit: Optional[int] = None) -> pyarrow.Table:
# Need to provide schema if no result rows, as arrow can't infer
# If there are rows, it is safest to infer schema from batches.
# Any discrepancies between predicted schema and actual schema will produce errors.
@@ -97,18 +97,21 @@ def to_arrow_table(self) -> pyarrow.Table:
peek_value = list(peek_it)
# TODO: Enforce our internal schema on the table for consistency
if len(peek_value) > 0:
return pyarrow.Table.from_batches(
itertools.chain(peek_value, batches), # reconstruct
)
batches = itertools.chain(peek_value, batches) # reconstruct
if limit:
batches = pyarrow_utils.truncate_pyarrow_iterable(
batches, max_results=limit
)
return pyarrow.Table.from_batches(batches)
else:
try:
return self._schema.to_pyarrow().empty_table()
except pa.ArrowNotImplementedError:
# Bug with some pyarrow versions, empty_table only supports base storage types, not extension types.
return self._schema.to_pyarrow(use_storage_types=True).empty_table()

def to_pandas(self) -> pd.DataFrame:
return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema)
def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
return io_pandas.arrow_to_pandas(self.to_arrow_table(limit=limit), self._schema)

def to_pandas_batches(
self, page_size: Optional[int] = None, max_results: Optional[int] = None
@@ -158,7 +161,7 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
...

@abc.abstractmethod
def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
...

@property
@@ -200,9 +203,9 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> bigframes.core.schema.ArraySchema:
return self._data.schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
return ResultsIterator(
iter(self._data.to_arrow()[1]),
iter(self._data.to_arrow(sample_rate=sample_rate)[1]),
self.schema,
self._data.metadata.row_count,
self._data.metadata.total_bytes,
@@ -226,7 +229,7 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> bigframes.core.schema.ArraySchema:
return self._schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
return ResultsIterator(iter([]), self.schema, 0, 0)


@@ -260,12 +263,13 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
source_ids = [selection[0] for selection in self._selected_fields]
return self._data.schema.select(source_ids).rename(dict(self._selected_fields))

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
read_batches = bq_data.get_arrow_batches(
self._data,
[x[0] for x in self._selected_fields],
self._storage_client,
self._project_id,
sample_rate=sample_rate,
)
arrow_batches: Iterator[pa.RecordBatch] = map(
functools.partial(
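The limit plumbing in to_arrow_table depends on cutting off a stream of record batches at a row count. The diff calls pyarrow_utils.truncate_pyarrow_iterable for this; its implementation is not shown here, but a sketch of the assumed behavior looks like:

from typing import Iterable, Iterator

import pyarrow as pa

def truncate_batches(batches: Iterable[pa.RecordBatch], max_results: int) -> Iterator[pa.RecordBatch]:
    # Yield whole batches until the row budget is spent, slicing the last batch if needed.
    remaining = max_results
    for batch in batches:
        if remaining <= 0:
            return
        if batch.num_rows > remaining:
            yield batch.slice(0, remaining)
            return
        remaining -= batch.num_rows
        yield batch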
2 changes: 1 addition & 1 deletion tests/system/small/test_anywidget.py
@@ -165,7 +165,7 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> Any:
return schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate=None) -> ResultsIterator:
return ResultsIterator(
arrow_batches_val,
self.schema,
6 changes: 3 additions & 3 deletions tests/system/small/test_dataframe.py
@@ -4524,7 +4524,7 @@ def test_df_kurt(scalars_dfs):
"n_default",
],
)
def test_sample(scalars_dfs, frac, n, random_state):
def test_df_to_pandas_sample(scalars_dfs, frac, n, random_state):
scalars_df, _ = scalars_dfs
df = scalars_df.sample(frac=frac, n=n, random_state=random_state)
bf_result = df.to_pandas()
@@ -4535,15 +4535,15 @@ def test_sample(scalars_dfs, frac, n, random_state):
assert bf_result.shape[1] == scalars_df.shape[1]


def test_sample_determinism(penguins_df_default_index):
def test_df_to_pandas_sample_determinism(penguins_df_default_index):
df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
bf_result = df.to_pandas()
bf_result2 = df.to_pandas()

pandas.testing.assert_frame_equal(bf_result, bf_result2)


def test_sample_raises_value_error(scalars_dfs):
def test_df_to_pandas_sample_raises_value_error(scalars_dfs):
scalars_df, _ = scalars_dfs
with pytest.raises(
ValueError, match="Only one of 'n' or 'frac' parameter can be specified."
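Taken together, the renamed tests pin down the user-facing contract; usage mirroring them (scalars_df and penguins_df_default_index stand in for the test fixtures):

import pytest

# frac-based sample, materialized locally
bf_result = scalars_df.sample(frac=0.5).to_pandas()

# a seeded n-based sample is deterministic across repeated materializations
df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
assert df.to_pandas().equals(df.to_pandas())

# supplying both n and frac is rejected
with pytest.raises(ValueError, match="Only one of 'n' or 'frac'"):
    scalars_df.sample(n=10, frac=0.5)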