Skip to content

Commit 3784486

Browse files
duxiao1212facebook-github-bot
authored and committed
misc: Refactor ShuffleTest (prestodb#26643)
Summary: Pull Request resolved: prestodb#26643. Completely replace TestShuffleWriter/Reader with local shuffle, and refactor the ShuffleTest class. Differential Revision: D87160791
1 parent ebee40b commit 3784486

File tree

4 files changed

+861
-1061
lines changed

4 files changed

+861
-1061
lines changed

presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,30 @@ void LocalShuffleWriter::collect(
353353
int32_t partition,
354354
std::string_view key,
355355
std::string_view data) {
356+
LOG(INFO) << "LocalShuffleWriter::collect START partition=" << partition
357+
<< " sortedShuffle=" << sortedShuffle_
358+
<< " keySize=" << key.size()
359+
<< " dataSize=" << data.size();
356360
VELOX_CHECK_LT(partition, numPartitions_);
361+
// For non-sorted shuffle, key must be empty
362+
// For sorted shuffle, key is used for sorting
357363
VELOX_CHECK(
358364
sortedShuffle_ || key.empty(),
359-
"key '{}' must be empty for non-sorted shuffle",
360-
key);
365+
"key must be empty for non-sorted shuffle, got key size {}",
366+
key.size());
367+
368+
// Testvalue injection point for exception testing
369+
velox::common::testutil::TestValue::adjust(
370+
"facebook::presto::operators::LocalShuffleWriter::collect", this);
371+
372+
// Log the actual data content for debugging
373+
if (data.size() >= 16) {
374+
LOG(INFO) << "LocalShuffleWriter::collect data first 16 bytes (hex): "
375+
<< folly::hexlify(std::string_view(data.data(), std::min(data.size(), size_t(16))));
376+
}
377+
361378
const auto rowSize = this->rowSize(key.size(), data.size());
379+
LOG(INFO) << "LocalShuffleWriter::collect rowSize=" << rowSize << " keySize=" << key.size();
362380

363381
auto& buffer = inProgressPartitions_[partition];
364382
if (buffer == nullptr) {
@@ -377,15 +395,19 @@ void LocalShuffleWriter::collect(
377395
}
378396

379397
void LocalShuffleWriter::noMoreData(bool success) {
398+
LOG(INFO) << "LocalShuffleWriter::noMoreData START success=" << success
399+
<< " numPartitions=" << numPartitions_;
380400
// Delete all shuffle files on failure.
381401
if (!success) {
382402
cleanup();
383403
}
384404
for (auto i = 0; i < numPartitions_; ++i) {
385405
if (inProgressSizes_[i] > 0) {
406+
LOG(INFO) << "LocalShuffleWriter::noMoreData writing block for partition " << i;
386407
writeBlock(i);
387408
}
388409
}
410+
LOG(INFO) << "LocalShuffleWriter::noMoreData COMPLETE";
389411
}
390412

391413
LocalShuffleReader::LocalShuffleReader(
@@ -403,9 +425,11 @@ LocalShuffleReader::LocalShuffleReader(
403425
}
404426

405427
void LocalShuffleReader::initialize() {
428+
LOG(INFO) << "LocalShuffleReader::initialize START sortedShuffle=" << sortedShuffle_;
406429
VELOX_CHECK(!initialized_, "LocalShuffleReader already initialized");
407430

408431
readPartitionFiles_ = getReadPartitionFiles();
432+
LOG(INFO) << "LocalShuffleReader::initialize found " << readPartitionFiles_.size() << " files";
409433

410434
if (sortedShuffle_ && !readPartitionFiles_.empty()) {
411435
std::vector<std::unique_ptr<velox::MergeStream>> streams;
@@ -432,13 +456,15 @@ void LocalShuffleReader::initialize() {
432456
}
433457

434458
initialized_ = true;
459+
LOG(INFO) << "LocalShuffleReader::initialize COMPLETE";
435460
}
436461

437462
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextSorted(
438463
uint64_t maxBytes) {
439464
std::vector<std::unique_ptr<ReadBatch>> batches;
440465

441466
if (merge_ == nullptr) {
467+
LOG(INFO) << "LocalShuffleReader::nextSorted merge is null, returning empty";
442468
return batches;
443469
}
444470

@@ -449,37 +475,42 @@ std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextSorted(
449475
while (auto* stream = merge_->next()) {
450476
auto* reader = dynamic_cast<SortedFileInputStream*>(stream);
451477
const auto data = reader->currentData();
452-
const auto rowSize = kUint32Size + data.size();
478+
const auto key = reader->currentKey();
479+
480+
LOG(INFO) << "LocalShuffleReader::nextSorted processing row: keySize=" << key.size()
481+
<< " dataSize=" << data.size() << " bufferUsed=" << bufferUsed;
453482

454-
// With the current row the bufferUsed byte will exceed the maxBytes
455-
if (bufferUsed + rowSize > maxBytes) {
483+
// For sorted shuffle: data is already CompactRow serialized data
484+
// Return it as-is, NO size prefix (same as CoscoShuffleReader)
485+
if (bufferUsed + data.size() > maxBytes) {
456486
if (bufferUsed > 0) {
457-
// We have some rows already, return them to release the memory
487+
LOG(INFO) << "LocalShuffleReader::nextSorted returning batch with " << rows.size() << " rows";
458488
batches.push_back(
459489
std::make_unique<ReadBatch>(std::move(rows), std::move(batchBuffer)));
460490
return batches;
461491
}
462-
// Single row exceeds buffer - allocate larger buffer for this row
463-
batchBuffer = velox::AlignedBuffer::allocate<char>(rowSize, pool_, 0);
492+
// Single row exceeds buffer - allocate larger buffer
493+
batchBuffer = velox::AlignedBuffer::allocate<char>(data.size(), pool_, 0);
464494
bufferUsed = 0;
465495
}
466496

467-
// Write row: [dataSize][data]
497+
// Copy data as-is without size prefix
468498
char* writePos = batchBuffer->asMutable<char>() + bufferUsed;
469-
*reinterpret_cast<TRowSize*>(writePos) =
470-
folly::Endian::big(static_cast<TRowSize>(data.size()));
471-
472499
if (!data.empty()) {
473-
memcpy(writePos + sizeof(TRowSize), data.data(), data.size());
500+
memcpy(writePos, data.data(), data.size());
474501
}
475502

476-
rows.emplace_back(batchBuffer->as<char>() + bufferUsed, rowSize);
477-
bufferUsed += rowSize;
503+
LOG(INFO) << "LocalShuffleReader::nextSorted wrote row at offset " << bufferUsed
504+
<< " size=" << data.size() << " (no prefix, like CoscoShuffle)";
505+
506+
rows.emplace_back(batchBuffer->as<char>() + bufferUsed, data.size());
507+
bufferUsed += data.size();
478508

479509
reader->next();
480510
}
481511

482512
if (!rows.empty()) {
513+
LOG(INFO) << "LocalShuffleReader::nextSorted final batch with " << rows.size() << " rows";
483514
batches.push_back(
484515
std::make_unique<ReadBatch>(std::move(rows), std::move(batchBuffer)));
485516
}
@@ -528,11 +559,19 @@ std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextUnsorted(
528559

529560
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
530561
LocalShuffleReader::next(uint64_t maxBytes) {
562+
LOG(INFO) << "LocalShuffleReader::next START maxBytes=" << maxBytes
563+
<< " sortedShuffle=" << sortedShuffle_;
531564
VELOX_CHECK(
532565
initialized_,
533566
"LocalShuffleReader::initialize() must be called before next()");
534-
return folly::makeSemiFuture(
535-
sortedShuffle_ ? nextSorted(maxBytes) : nextUnsorted(maxBytes));
567+
568+
// Testvalue injection point for exception testing
569+
velox::common::testutil::TestValue::adjust(
570+
"facebook::presto::operators::LocalShuffleReader::next", this);
571+
572+
auto result = sortedShuffle_ ? nextSorted(maxBytes) : nextUnsorted(maxBytes);
573+
LOG(INFO) << "LocalShuffleReader::next COMPLETE returned " << result.size() << " batches";
574+
return folly::makeSemiFuture(std::move(result));
536575
}
537576

538577
void LocalShuffleReader::noMoreData(bool success) {
@@ -576,42 +615,64 @@ std::shared_ptr<ShuffleReader> LocalPersistentShuffleFactory::createReader(
576615
const std::string& serializedStr,
577616
const int32_t /*partition*/,
578617
velox::memory::MemoryPool* pool) {
618+
LOG(INFO) << "LocalPersistentShuffleFactory::createReader START";
579619
const operators::LocalShuffleReadInfo readInfo =
580620
operators::LocalShuffleReadInfo::deserialize(serializedStr);
621+
// Check if sortedShuffle field is present in the JSON
622+
bool sortedShuffle = false;
623+
try {
624+
const auto jsonReadInfo = json::parse(serializedStr);
625+
if (jsonReadInfo.contains("sortedShuffle")) {
626+
jsonReadInfo.at("sortedShuffle").get_to(sortedShuffle);
627+
}
628+
} catch (const std::exception& /*e*/) {
629+
// If parsing fails or field doesn't exist, default to false
630+
sortedShuffle = false;
631+
}
632+
LOG(INFO) << "LocalPersistentShuffleFactory::createReader sortedShuffle=" << sortedShuffle
633+
<< " rootPath=" << readInfo.rootPath;
581634
auto reader = std::make_shared<LocalShuffleReader>(
582635
readInfo.rootPath,
583636
readInfo.queryId,
584637
readInfo.partitionIds,
585-
/*sortShuffle=*/false, // default to false for now
638+
sortedShuffle,
586639
pool);
587640
reader->initialize();
641+
LOG(INFO) << "LocalPersistentShuffleFactory::createReader COMPLETE";
588642
return reader;
589643
}
590644

591645
std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
592646
const std::string& serializedStr,
593647
velox::memory::MemoryPool* pool) {
648+
LOG(INFO) << "LocalPersistentShuffleFactory::createWriter START";
594649
static const uint64_t maxBytesPerPartition =
595650
SystemConfig::instance()->localShuffleMaxPartitionBytes();
596651
const operators::LocalShuffleWriteInfo writeInfo =
597652
operators::LocalShuffleWriteInfo::deserialize(serializedStr);
598-
return std::make_shared<LocalShuffleWriter>(
653+
// Check if sortedShuffle field is present in the JSON
654+
bool sortedShuffle = false;
655+
try {
656+
const auto jsonWriteInfo = json::parse(serializedStr);
657+
if (jsonWriteInfo.contains("sortedShuffle")) {
658+
jsonWriteInfo.at("sortedShuffle").get_to(sortedShuffle);
659+
}
660+
} catch (const std::exception& /*e*/) {
661+
// If parsing fails or field doesn't exist, default to false
662+
sortedShuffle = false;
663+
}
664+
LOG(INFO) << "LocalPersistentShuffleFactory::createWriter sortedShuffle=" << sortedShuffle
665+
<< " numPartitions=" << writeInfo.numPartitions
666+
<< " rootPath=" << writeInfo.rootPath;
667+
auto writer = std::make_shared<LocalShuffleWriter>(
599668
writeInfo.rootPath,
600669
writeInfo.queryId,
601670
writeInfo.shuffleId,
602671
writeInfo.numPartitions,
603672
maxBytesPerPartition,
604-
/*sortedShuffle=*/false, // default to false for now
673+
sortedShuffle,
605674
pool);
675+
LOG(INFO) << "LocalPersistentShuffleFactory::createWriter COMPLETE";
676+
return writer;
606677
}
607-
608-
// Testing function to expose extractRowMetadata for tests.
609-
// This will be removed after reader changes.
610-
std::vector<RowMetadata> testingExtractRowMetadata(
611-
const char* buffer,
612-
size_t bufferSize,
613-
bool sortedShuffle) {
614-
return extractRowMetadata(buffer, bufferSize, sortedShuffle);
615-
}
616-
617678
} // namespace facebook::presto::operators

presto-native-execution/presto_cpp/main/operators/LocalShuffle.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ inline std::strong_ordering compareKeys(
5353
reinterpret_cast<const uint8_t*>(key2.data() + key2.size()));
5454
}
5555

56-
// Testing function to expose extractRowMetadata for tests.
57-
std::vector<RowMetadata> testingExtractRowMetadata(
58-
const char* buffer,
59-
size_t bufferSize,
60-
bool sortedShuffle);
61-
6256
// LocalShuffleWriteInfo is used for containing shuffle write information.
6357
// This struct is a 1:1 strict API mapping to
6458
// presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkLocalShuffleWriteInfo.java
@@ -75,7 +69,6 @@ struct LocalShuffleWriteInfo {
7569
/// Structures are assumed to be encoded in JSON format.
7670
static LocalShuffleWriteInfo deserialize(const std::string& info);
7771
};
78-
7972
// LocalShuffleReadInfo is used for containing shuffle read metadata
8073
// This struct is a 1:1 strict API mapping to
8174
// presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkLocalShuffleReadInfo.java.
@@ -87,8 +80,6 @@ struct LocalShuffleReadInfo {
8780
std::string queryId;
8881
std::vector<std::string> partitionIds;
8982

90-
/// Deserializes shuffle information that is used by LocalPersistentShuffle.
91-
/// Structures are assumed to be encoded in JSON format.
9283
static LocalShuffleReadInfo deserialize(const std::string& info);
9384
};
9485

presto-native-execution/presto_cpp/main/operators/ShuffleRead.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ RowVectorPtr ShuffleRead::getOutput() {
8686
auto* batch = checked_pointer_cast<ShuffleRowBatch>(page.get());
8787
const auto& rows = batch->rows();
8888
for (const auto& row : rows) {
89+
LOG(INFO) << "ShuffleRead::getOutput adding row to rows_: size=" << row.size();
90+
if (row.size() >= 16) {
91+
LOG(INFO) << "ShuffleRead::getOutput row data (hex): "
92+
<< folly::hexlify(std::string_view(row.data(), std::min(row.size(), size_t(32))));
93+
}
8994
rows_.emplace_back(row);
9095
}
9196
}

0 commit comments

Comments
 (0)