Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 37 additions & 53 deletions python/python/lance/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def _efficient_sample(
n: int,
columns: Optional[Union[List[str], Dict[str, str]]],
batch_size: int,
max_takes: int,
) -> Generator[pa.RecordBatch, None, None]:
"""Sample n records from the dataset.

Mirrors the Rust ``sample_fsl_uniform`` strategy: generate n uniformly
random indices, sort them, and take in large contiguous chunks (default
8192 rows per take). Sorting allows the underlying object store to merge
adjacent row reads into fewer, larger range requests, which drastically
reduces I/O latency on remote storage (e.g. S3).

Parameters
----------
dataset : lance.LanceDataset
Expand All @@ -61,55 +66,47 @@ def _efficient_sample(
columns : list[str]
The columns to load.
batch_size : int
The batch size to use when loading the data.
max_takes : int
The maximum number of takes to perform. This is used to limit the number of
random reads. Large enough value can give a good random sample without
having to issue too many random reads.
The batch size to use when yielding output RecordBatches.

Returns
-------
Generator of a RecordBatch.
"""
buf: list[pa.RecordBatch] = []
total_records = len(dataset)
assert total_records > n
chunk_size = total_records // max_takes
chunk_sample_size = n // max_takes

num_sampled = 0

for idx, i in enumerate(range(0, total_records, chunk_size)):
# If we have already sampled enough, break. This can happen if there
# is a remainder in the division.
if num_sampled >= n:
break
num_sampled += chunk_sample_size
indices = np.random.choice(total_records, n, replace=False)
indices.sort()

# If we are at the last chunk, we may not have enough records to sample.
local_size = min(chunk_size, total_records - i)
local_sample_size = min(chunk_sample_size, local_size)
LOGGER.info(
"Sampling %d rows from %d total (sorted random indices, chunk take)",
n,
total_records,
)

if local_sample_size < local_size:
# Add more randomness within each chunk, if there is room.
offset = i + np.random.randint(0, local_size - local_sample_size)
else:
offset = i
take_chunk_size = 8192
buf: list[pa.RecordBatch] = []

buf.extend(
dataset.take(
list(range(offset, offset + local_sample_size)),
columns=columns,
).to_batches()
for start in range(0, len(indices), take_chunk_size):
chunk = indices[start : start + take_chunk_size].tolist()
buf.extend(dataset.take(chunk, columns=columns).to_batches())
LOGGER.info(
"Sampled chunk %d/%d, rows %d-%d",
start // take_chunk_size + 1,
math.ceil(len(indices) / take_chunk_size),
chunk[0],
chunk[-1],
)
if idx % 50 == 0:
LOGGER.info("Sampled at offset=%s, len=%s", offset, chunk_sample_size)
if sum(len(b) for b in buf) >= batch_size:
while sum(len(b) for b in buf) >= batch_size:
tbl = pa.Table.from_batches(buf)
buf.clear()
tbl = tbl.combine_chunks()
yield tbl.to_batches()[0]
del tbl
batch_tbl = tbl.slice(0, batch_size).combine_chunks()
rest_tbl = tbl.slice(batch_size)
yield batch_tbl.to_batches()[0]
del batch_tbl
if rest_tbl.num_rows > 0:
buf.extend(rest_tbl.to_batches())
del rest_tbl, tbl
if buf:
tbl = pa.Table.from_batches(buf).combine_chunks()
yield tbl.to_batches()[0]
Expand All @@ -121,10 +118,10 @@ def _filtered_efficient_sample(
n: int,
columns: List[str],
batch_size: int,
target_takes: int,
filter: str,
) -> Generator[pa.RecordBatch, None, None]:
total_records = len(dataset)
target_takes = max(1, n // 32)
shard_size = math.ceil(n / target_takes)
num_shards = math.ceil(total_records / shard_size)

Expand Down Expand Up @@ -189,10 +186,7 @@ def maybe_sample(
batch_size : int, optional
The batch size to use when loading the data, by default 10240.
max_takes : int, optional
The maximum number of takes to perform, by default 2048.
This is employed to minimize the number of random reads necessary for sampling.
A sufficiently large value can provide an effective random sample without
the need for excessive random reads.
Deprecated and ignored. Kept for API compatibility only.
filter : str, optional
The filter to apply to the dataset, by default None. If a filter is provided,
then we will first load all row ids in memory and then batch through the ids
Expand All @@ -215,19 +209,9 @@ def maybe_sample(
columns=columns, batch_size=batch_size, filter=filt
)
elif filt is not None:
yield from _filtered_efficient_sample(
dataset, n, columns, batch_size, max_takes, filt
)
elif n > max_takes:
yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
yield from _filtered_efficient_sample(dataset, n, columns, batch_size, filt)
else:
choices = np.random.choice(len(dataset), n, replace=False)
idx = 0
while idx < len(choices):
end = min(idx + batch_size, len(choices))
tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
yield tbl.to_batches()[0]
idx += batch_size
yield from _efficient_sample(dataset, n, columns, batch_size)


T = TypeVar("T")
Expand Down
45 changes: 27 additions & 18 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def train_ivf_centroids_on_accelerator(
) -> Tuple[np.ndarray, Any]:
"""Use accelerator (GPU or MPS) to train kmeans."""

from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

metric_type = _normalize_metric_type(metric_type)
Expand All @@ -234,28 +233,35 @@ def train_ivf_centroids_on_accelerator(
else:
filt = None

LOGGER.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt)

ds = TorchDataset(
# Sample once into memory (mirrors the Rust CPU path), then reuse for
# both initial centroid selection and KMeans training.
LOGGER.info(
"Sampling %d vectors for IVF training from %s (filt=%s)",
sample_size,
dataset,
batch_size=k,
columns=[column],
samples=sample_size,
filter=filt,
filt,
)

init_centroids = next(iter(ds))
LOGGER.info("Done sampling: centroids shape: %s", init_centroids.shape)
from .sampler import maybe_sample

ds = TorchDataset(
dataset,
batch_size=20480,
columns=[column],
samples=sample_size,
filter=filt,
cache=True,
batches = list(maybe_sample(dataset, sample_size, [column], filt=filt))
tbl = pa.concat_tables(
[pa.Table.from_batches([b]) for b in batches]
).combine_chunks()
training_data = tbl.column(column).to_numpy(zero_copy_only=False)
training_data = np.stack(training_data).astype(np.float32)
del tbl, batches

LOGGER.info(
"Sampled %d vectors into memory (%.1f MB)",
training_data.shape[0],
training_data.nbytes / 1024 / 1024,
)

# Pick k random rows as initial centroids.
init_indices = np.random.choice(training_data.shape[0], k, replace=False)
init_centroids = torch.from_numpy(training_data[init_indices])

LOGGER.info("Training IVF partitions using GPU(%s)", accelerator)
kmeans = KMeans(
k,
Expand All @@ -264,7 +270,10 @@ def train_ivf_centroids_on_accelerator(
device=accelerator,
centroids=init_centroids,
)
kmeans.fit(ds)
# Pass the full training data as a Tensor so KMeans iterates in memory
# (no disk cache, no re-sampling).
kmeans.fit(torch.from_numpy(training_data))
del training_data

centroids = (
kmeans.centroids.cpu().numpy().astype(vector_value_type.to_pandas_dtype())
Expand Down
5 changes: 2 additions & 3 deletions python/python/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"master_addr": "127.0.0.1",
"seed": 42,
"test_shard_ratio": 0.5,
"max_takes_factor": 0.1,
}


Expand Down Expand Up @@ -270,8 +269,8 @@ def test_sample_dataset(tmp_path: Path, nrows: int):
assert simple_scan[0].schema == pa.schema([pa.field("vec", fsl.type)])
assert simple_scan[0].num_rows == min(nrows, 128)

# Random path.
large_scan = list(maybe_sample(ds, 128, ["vec"], max_takes=32))
# Sorted-index take path (n < len(dataset)).
large_scan = list(maybe_sample(ds, 128, ["vec"]))
assert len(large_scan) == 1
assert isinstance(large_scan[0], pa.RecordBatch)
assert large_scan[0].schema == pa.schema([pa.field("vec", fsl.type)])
Expand Down
2 changes: 1 addition & 1 deletion python/python/tests/torch_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def iter_over_dataset(tmp_path):
assert batch["vec"].shape[1] == 32
assert total_rows == 1024

# test when sample size is greater than max_takes
# test larger sample size
torch_ds = LanceDataset(
ds,
batch_size=256,
Expand Down
Loading