Skip to content

Commit 262eac7

Browse files
duxiao1212facebook-github-bot
authored andcommitted
feat: Impl sort key for LocalShuffleReader (#26620)
Summary: Implement sorted shuffle k-way merge for LocalShuffleReader, when it's sortedShuffle. Added k-way merge support using TreeOfLosers to efficiently merge multiple sorted shuffle files. The reader streams data from sorted files and returns merged results in sorted order. Reviewed By: tanjialiang Differential Revision: D86888221
1 parent a070982 commit 262eac7

File tree

3 files changed

+411
-38
lines changed

3 files changed

+411
-38
lines changed

presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp

Lines changed: 218 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,102 @@
1515
#include "presto_cpp/external/json/nlohmann/json.hpp"
1616
#include "presto_cpp/main/common/Configs.h"
1717

18-
#include <folly/lang/Bits.h>
18+
#include "velox/common/file/FileInputStream.h"
1919

20-
using namespace facebook::velox::exec;
21-
using namespace facebook::velox;
20+
#include <boost/range/algorithm/sort.hpp>
2221

2322
namespace facebook::presto::operators {
2423

2524
using json = nlohmann::json;
2625

2726
namespace {
2827

28+
using TStreamIdx = uint16_t;
29+
30+
/// SortedFileInputStream reads sorted (key, data) pairs from a single
31+
/// shuffle file with buffered I/O. It extends FileInputStream for efficient
32+
/// buffered I/O and implements MergeStream interface for k-way merge.
33+
class SortedFileInputStream final : public velox::common::FileInputStream,
34+
public velox::MergeStream {
35+
public:
36+
SortedFileInputStream(
37+
const std::string& filePath,
38+
TStreamIdx streamIdx,
39+
velox::memory::MemoryPool* pool,
40+
size_t bufferSize = kDefaultInputStreamBufferSize)
41+
: velox::common::FileInputStream(
42+
velox::filesystems::getFileSystem(filePath, nullptr)
43+
->openFileForRead(filePath),
44+
bufferSize,
45+
pool),
46+
streamIdx_(streamIdx) {
47+
next();
48+
}
49+
50+
~SortedFileInputStream() override = default;
51+
52+
bool next() {
53+
if (atEnd()) {
54+
currentKey_ = {};
55+
currentData_ = {};
56+
keyStorage_.clear();
57+
dataStorage_.clear();
58+
return false;
59+
}
60+
const TRowSize keySize = folly::Endian::big(read<TRowSize>());
61+
const TRowSize dataSize = folly::Endian::big(read<TRowSize>());
62+
63+
currentKey_ = nextStringView(keySize, keyStorage_);
64+
currentData_ = nextStringView(dataSize, dataStorage_);
65+
return true;
66+
}
67+
68+
std::string_view currentKey() const {
69+
return currentKey_;
70+
}
71+
72+
std::string_view currentData() const {
73+
return currentData_;
74+
}
75+
76+
bool hasData() const override {
77+
return !currentData_.empty() || !atEnd();
78+
}
79+
80+
bool operator<(const velox::MergeStream& other) const override {
81+
const auto* otherReader = static_cast<const SortedFileInputStream*>(&other);
82+
if (currentKey_ != otherReader->currentKey_) {
83+
return compareKeys(currentKey_, otherReader->currentKey_);
84+
}
85+
return streamIdx_ < otherReader->streamIdx_;
86+
}
87+
88+
private:
89+
// Returns string_view using zero-copy when data fits in buffer,
90+
// otherwise copies to storage when crossing buffer boundaries.
91+
std::string_view nextStringView(TRowSize size, std::string& storage) {
92+
if (size == 0) {
93+
return {};
94+
}
95+
auto view = nextView(size);
96+
if (view.size() == size) {
97+
return view;
98+
}
99+
storage.resize(size);
100+
std::memcpy(storage.data(), view.data(), view.size());
101+
readBytes(
102+
reinterpret_cast<uint8_t*>(storage.data()) + view.size(),
103+
size - view.size());
104+
return std::string_view(storage);
105+
}
106+
107+
const TStreamIdx streamIdx_;
108+
std::string_view currentKey_;
109+
std::string_view currentData_;
110+
std::string keyStorage_;
111+
std::string dataStorage_;
112+
};
113+
29114
std::vector<RowMetadata>
30115
extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
31116
std::vector<RowMetadata> rows;
@@ -91,13 +176,9 @@ extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
91176

92177
inline std::string_view
93178
extractRowData(const RowMetadata& row, const char* buffer, bool sortedShuffle) {
94-
if (sortedShuffle) {
95-
const size_t dataOffset = row.rowStart + (kUint32Size * 2) + row.keySize;
96-
return {buffer + dataOffset, row.dataSize};
97-
} else {
98-
const size_t dataOffset = row.rowStart + kUint32Size;
99-
return {buffer + dataOffset, row.dataSize};
100-
}
179+
const auto dataOffset = row.rowStart +
180+
(sortedShuffle ? (kUint32Size * 2) + row.keySize : kUint32Size);
181+
return {buffer + dataOffset, row.dataSize};
101182
}
102183

103184
std::vector<RowMetadata> extractAndSortRowMetadata(
@@ -106,10 +187,8 @@ std::vector<RowMetadata> extractAndSortRowMetadata(
106187
bool sortedShuffle) {
107188
auto rows = extractRowMetadata(buffer, bufferSize, sortedShuffle);
108189
if (!rows.empty() && sortedShuffle) {
109-
std::sort(
110-
rows.begin(),
111-
rows.end(),
112-
[buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
190+
boost::range::sort(
191+
rows, [buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
113192
const char* lhsKey = buffer + lhs.rowStart + (kUint32Size * 2);
114193
const char* rhsKey = buffer + rhs.rowStart + (kUint32Size * 2);
115194
return compareKeys(
@@ -276,10 +355,11 @@ void LocalShuffleWriter::collect(
276355
sortedShuffle_ || key.empty(),
277356
"key '{}' must be empty for non-sorted shuffle",
278357
key);
358+
279359
const auto rowSize = this->rowSize(key.size(), data.size());
280360
auto& buffer = inProgressPartitions_[partition];
281361
if (buffer == nullptr) {
282-
buffer = AlignedBuffer::allocate<char>(
362+
buffer = velox::AlignedBuffer::allocate<char>(
283363
std::max(static_cast<uint64_t>(rowSize), maxBytesPerPartition_),
284364
pool_,
285365
0);
@@ -319,31 +399,105 @@ LocalShuffleReader::LocalShuffleReader(
319399
fileSystem_ = velox::filesystems::getFileSystem(rootPath_, nullptr);
320400
}
321401

322-
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
323-
LocalShuffleReader::next(uint64_t maxBytes) {
324-
if (readPartitionFiles_.empty()) {
325-
readPartitionFiles_ = getReadPartitionFiles();
402+
void LocalShuffleReader::initialize() {
403+
VELOX_CHECK(!initialized_, "LocalShuffleReader already initialized");
404+
readPartitionFiles_ = getReadPartitionFiles();
405+
406+
if (sortedShuffle_ && !readPartitionFiles_.empty()) {
407+
std::vector<std::unique_ptr<velox::MergeStream>> streams;
408+
streams.reserve(readPartitionFiles_.size());
409+
TStreamIdx streamIdx = 0;
410+
for (const auto& filename : readPartitionFiles_) {
411+
VELOX_CHECK(
412+
!filename.empty(),
413+
"Invalid empty shuffle file path for query {}, partitions: [{}]",
414+
queryId_,
415+
folly::join(", ", partitionIds_));
416+
auto reader =
417+
std::make_unique<SortedFileInputStream>(filename, streamIdx, pool_);
418+
if (reader->hasData()) {
419+
streams.push_back(std::move(reader));
420+
++streamIdx;
421+
}
422+
}
423+
if (!streams.empty()) {
424+
merge_ =
425+
std::make_unique<velox::TreeOfLosers<velox::MergeStream, uint16_t>>(
426+
std::move(streams));
427+
}
428+
}
429+
430+
initialized_ = true;
431+
}
432+
433+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextSorted(
434+
uint64_t maxBytes) {
435+
std::vector<std::unique_ptr<ReadBatch>> batches;
436+
437+
if (merge_ == nullptr) {
438+
return batches;
439+
}
440+
441+
auto batchBuffer = velox::AlignedBuffer::allocate<char>(maxBytes, pool_, 0);
442+
std::vector<std::string_view> rows;
443+
uint64_t bufferUsed = 0;
444+
445+
while (auto* stream = merge_->next()) {
446+
auto* reader = dynamic_cast<SortedFileInputStream*>(stream);
447+
const auto data = reader->currentData();
448+
449+
if (bufferUsed + data.size() > maxBytes) {
450+
if (bufferUsed > 0) {
451+
batches.push_back(
452+
std::make_unique<ReadBatch>(
453+
std::move(rows), std::move(batchBuffer)));
454+
return batches;
455+
}
456+
// Single row exceeds buffer - allocate larger buffer
457+
batchBuffer = velox::AlignedBuffer::allocate<char>(data.size(), pool_, 0);
458+
bufferUsed = 0;
459+
}
460+
461+
char* writePos = batchBuffer->asMutable<char>() + bufferUsed;
462+
if (!data.empty()) {
463+
memcpy(writePos, data.data(), data.size());
464+
}
465+
466+
rows.emplace_back(batchBuffer->as<char>() + bufferUsed, data.size());
467+
bufferUsed += data.size();
468+
reader->next();
469+
}
470+
471+
if (!rows.empty()) {
472+
batches.push_back(
473+
std::make_unique<ReadBatch>(std::move(rows), std::move(batchBuffer)));
326474
}
327475

476+
return batches;
477+
}
478+
479+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextUnsorted(
480+
uint64_t maxBytes) {
328481
std::vector<std::unique_ptr<ReadBatch>> batches;
329482
uint64_t totalBytes{0};
330-
// Read files until we reach maxBytes limit or run out of files.
483+
331484
while (readPartitionFileIndex_ < readPartitionFiles_.size()) {
332485
const auto filename = readPartitionFiles_[readPartitionFileIndex_];
333486
auto file = fileSystem_->openFileForRead(filename);
334487
const auto fileSize = file->size();
335488

336-
// Stop if adding this file would exceed maxBytes (unless we haven't read
337-
// any files yet)
489+
// TODO: Refactor to use streaming I/O with bounded buffer size instead of
490+
// loading entire files into memory at once. A streaming approach would
491+
// reduce peak memory consumption and enable processing arbitrarily large
492+
// shuffle files while maintaining constant memory usage.
338493
if (!batches.empty() && totalBytes + fileSize > maxBytes) {
339494
break;
340495
}
341496

342-
auto buffer = AlignedBuffer::allocate<char>(fileSize, pool_, 0);
497+
auto buffer = velox::AlignedBuffer::allocate<char>(fileSize, pool_, 0);
343498
file->pread(0, fileSize, buffer->asMutable<void>());
344499
++readPartitionFileIndex_;
345500

346-
// Parse the buffer to extract individual rows
347501
const char* data = buffer->as<char>();
348502
const auto parsedRows = extractRowMetadata(data, fileSize, sortedShuffle_);
349503
std::vector<std::string_view> rows;
@@ -357,7 +511,17 @@ LocalShuffleReader::next(uint64_t maxBytes) {
357511
std::make_unique<ReadBatch>(std::move(rows), std::move(buffer)));
358512
}
359513

360-
return folly::makeSemiFuture(std::move(batches));
514+
return batches;
515+
}
516+
517+
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
518+
LocalShuffleReader::next(uint64_t maxBytes) {
519+
VELOX_CHECK(
520+
initialized_,
521+
"LocalShuffleReader::initialize() must be called before next()");
522+
523+
return folly::makeSemiFuture(
524+
sortedShuffle_ ? nextSorted(maxBytes) : nextUnsorted(maxBytes));
361525
}
362526

363527
void LocalShuffleReader::noMoreData(bool success) {
@@ -403,12 +567,26 @@ std::shared_ptr<ShuffleReader> LocalPersistentShuffleFactory::createReader(
403567
velox::memory::MemoryPool* pool) {
404568
const operators::LocalShuffleReadInfo readInfo =
405569
operators::LocalShuffleReadInfo::deserialize(serializedStr);
406-
return std::make_shared<LocalShuffleReader>(
570+
// Check if sortedShuffle field is present in the JSON
571+
bool sortedShuffle = false;
572+
try {
573+
const auto jsonReadInfo = json::parse(serializedStr);
574+
if (jsonReadInfo.contains("sortedShuffle")) {
575+
jsonReadInfo.at("sortedShuffle").get_to(sortedShuffle);
576+
}
577+
} catch (const std::exception& /*e*/) {
578+
// If parsing fails or field doesn't exist, default to false
579+
sortedShuffle = false;
580+
}
581+
582+
auto reader = std::make_shared<LocalShuffleReader>(
407583
readInfo.rootPath,
408584
readInfo.queryId,
409585
readInfo.partitionIds,
410-
/*sortShuffle=*/false, // default to false for now
586+
sortedShuffle,
411587
pool);
588+
reader->initialize();
589+
return reader;
412590
}
413591

414592
std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
@@ -418,13 +596,25 @@ std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
418596
SystemConfig::instance()->localShuffleMaxPartitionBytes();
419597
const operators::LocalShuffleWriteInfo writeInfo =
420598
operators::LocalShuffleWriteInfo::deserialize(serializedStr);
599+
// Check if sortedShuffle field is present in the JSON
600+
bool sortedShuffle = false;
601+
try {
602+
const auto jsonWriteInfo = json::parse(serializedStr);
603+
if (jsonWriteInfo.contains("sortedShuffle")) {
604+
jsonWriteInfo.at("sortedShuffle").get_to(sortedShuffle);
605+
}
606+
} catch (const std::exception& /*e*/) {
607+
// If parsing fails or field doesn't exist, default to false
608+
sortedShuffle = false;
609+
}
610+
421611
return std::make_shared<LocalShuffleWriter>(
422612
writeInfo.rootPath,
423613
writeInfo.queryId,
424614
writeInfo.shuffleId,
425615
writeInfo.numPartitions,
426616
maxBytesPerPartition,
427-
/*sortedShuffle=*/false, // default to false for now
617+
sortedShuffle,
428618
pool);
429619
}
430620

@@ -436,5 +626,4 @@ std::vector<RowMetadata> testingExtractRowMetadata(
436626
bool sortedShuffle) {
437627
return extractRowMetadata(buffer, bufferSize, sortedShuffle);
438628
}
439-
440629
} // namespace facebook::presto::operators

0 commit comments

Comments
 (0)