Merged
13 changes: 7 additions & 6 deletions .github/workflows/integration_test_8gpu_features.yaml
@@ -90,13 +90,14 @@ jobs:
sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"

python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8

# Verify the accuracy.
echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity"
# Verify the accuracy first.
echo "Checking FSDP8 v.s. HSDP (4, 2) accuracy parity"
export baseline_options="--parallelism.data_parallel_replicate_degree=1"
export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2"
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1
export test_options="--parallelism.data_parallel_replicate_degree=4"
python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result tests/assets/losses/llama3.txt
rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*
Contributor
Is this because you dump something to the folder, so that later runs will complain it's not empty? I think we should make this dump folder optional so that the first run doesn't use it at all.

Contributor Author
We need artifacts-to-be-uploaded because torchtitan writes to an output folder, and the default is outputs. But creating outputs fails because the file system is read-only. So, basically, any TorchTitan job we run has to redirect its outputs to artifacts-to-be-uploaded.

I feel making outputs optional would be too much, because the trainer has several checks around it, and all of this is just for CI. I would rather say the integration tests shouldn't expect the folder to be empty.
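As a concrete illustration of the redirect described above — a sketch only; the `--job.dump_folder` flag name is inferred from the script's `--job-dump-folder` option, and the invocation shape is hypothetical:

```python
# Sketch: CI cannot create the default "outputs" dir on a read-only
# filesystem, so the dump folder is redirected to the writable
# artifacts dir under RUNNER_TEMP.
import os
import tempfile

runner_temp = os.environ.get("RUNNER_TEMP", tempfile.gettempdir())
dump_folder = os.path.join(runner_temp, "artifacts-to-be-uploaded")
os.makedirs(dump_folder, exist_ok=True)

# Hypothetical invocation shape; the flag name is an assumption:
cmd = f"torchrun train.py --job.dump_folder={dump_folder}"
print(cmd)
```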


python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8

# Cleanup the checkpoints so that we don't waste network bandwidth and time.
rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint
149 changes: 144 additions & 5 deletions scripts/loss_compare.py
@@ -168,6 +168,9 @@ def validate_arguments(
test_train_file: str,
test_options: str,
steps: int,
assert_equal: bool,
export_result: str | None,
import_result: str | None,
) -> None:
"""Validate command line arguments."""
# Validate commit arguments - if one is ".", both must be "."
@@ -201,6 +204,34 @@
log_print(f"Error: --steps must be a positive integer, got: {steps}")
sys.exit(1)

# Validate export-result requires assert-equal
if export_result and not assert_equal:
log_print("Error: --export-result requires --assert-equal")
log_print(" Export only happens when losses are verified to match")
sys.exit(1)

# Validate import-result requires assert-equal
if import_result and not assert_equal:
log_print("Error: --import-result requires --assert-equal")
log_print(" Import is used to verify all losses match")
sys.exit(1)

# Validate export-result and import-result are mutually exclusive
if export_result and import_result:
log_print("Error: --export-result and --import-result cannot be used together")
log_print(" Use export to save results or import to compare against saved results")
sys.exit(1)

# Validate import file exists
if import_result and not os.path.exists(import_result):
log_print(f"Error: Import file does not exist: {import_result}")
sys.exit(1)


# =============================================================================
# SETUP FUNCTIONS
@@ -433,6 +464,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]:
return losses


def export_losses_to_file(losses: dict[int, float], export_path: str) -> None:
"""Export losses to file and stdout.

Args:
losses: Dictionary mapping step numbers to loss values
export_path: Path to export file
"""
log_print(f"Exporting losses to {export_path}")

# Write to file and collect output for stdout
with open(export_path, "w") as f:
for step in sorted(losses.keys()):
loss = losses[step]
line = f"{step} {loss}"
f.write(line + "\n")

log_print(f"Exported {len(losses)} loss values:")
log_print()

# Output to stdout in same format
for step in sorted(losses.keys()):
loss = losses[step]
print(f"{step} {loss}")

log_print()
log_print(f"Losses saved to: {export_path}")


def extract_loss_data(output_folder: str | None) -> None:
"""Extract loss data from logs."""
if not output_folder:
@@ -556,13 +615,18 @@ def perform_loss_analysis(
generate_summary_statistics(baseline_losses, test_losses, stats_file)


def assert_losses_equal(baseline_log: str, test_log: str) -> None:
"""Assert that losses are equal between baseline and test using
unittest.
def assert_losses_equal(
baseline_log: str, test_log: str, import_result: str | None = None
) -> None:
"""Assert that losses are equal between baseline and test using unittest.

If import_result is provided, also compares baseline with imported losses.
"""
log_print("Asserting losses are equal...")
log_print(f"Baseline log: {baseline_log}")
log_print(f"Test log: {test_log}")
if import_result:
log_print(f"Import file: {import_result}")

# Extract losses from both logs
baseline_losses = extract_losses_from_log(baseline_log)
@@ -579,6 +643,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None:
log_print("Error: No losses found in test log")
sys.exit(1)

# Load imported losses if provided
imported_losses = None
if import_result:
imported_losses = read_losses_from_file(import_result)
log_print(f"Loaded {len(imported_losses)} steps from import file")
if not imported_losses:
log_print("Error: No losses found in import file")
sys.exit(1)

# Create a test case
class LossEqualityTest(unittest.TestCase):
def test_losses_equal(self):
@@ -593,17 +666,41 @@ def test_losses_equal(self):
f"test has {len(test_steps)} steps",
)

# If imported losses exist, check steps match
if imported_losses:
imported_steps = set(imported_losses.keys())
self.assertEqual(
baseline_steps,
imported_steps,
f"Steps mismatch: baseline has {len(baseline_steps)} steps, "
f"imported has {len(imported_steps)} steps",
)

# Check that losses are equal for each step
for step in sorted(baseline_steps):
baseline_loss = baseline_losses[step]
test_loss = test_losses[step]

# Compare baseline vs test
self.assertEqual(
baseline_loss,
test_loss,
f"Loss mismatch at step {step}: "
f"baseline={baseline_loss}, test={test_loss}",
)

# Compare baseline vs imported (if provided)
# No need to compare test vs imported since:
# baseline==test and baseline==imported implies test==imported
if imported_losses:
imported_loss = imported_losses[step]
self.assertEqual(
baseline_loss,
imported_loss,
f"Loss mismatch at step {step}: "
f"baseline={baseline_loss}, imported={imported_loss}",
)

# Run the test
suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest)
runner = unittest.TextTestRunner(verbosity=2)
@@ -613,7 +710,13 @@ def test_losses_equal(self):
log_print("Loss assertion failed!")
sys.exit(1)
else:
log_print("All losses are equal. Assertion passed!")
if import_result:
log_print(
"All losses are equal (baseline, test, and imported). "
"Assertion passed!"
)
else:
log_print("All losses are equal. Assertion passed!")
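The three-way comparison above leans on transitivity of equality — checking baseline↔test and baseline↔imported is enough, and no test↔imported loop is needed. A minimal sketch of that reasoning with illustrative values:

```python
# If baseline == test and baseline == imported at every step, then
# test == imported follows; the third comparison is never run.
baseline = {1: 8.1376, 2: 7.841}
test = {1: 8.1376, 2: 7.841}
imported = {1: 8.1376, 2: 7.841}

for step in baseline:
    assert baseline[step] == test[step]
    assert baseline[step] == imported[step]

# Implied by transitivity, without a dedicated check:
assert all(test[s] == imported[s] for s in test)
print("transitivity holds")
```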


def cleanup_temp_files(output_folder: str | None) -> None:
@@ -756,6 +859,24 @@ def parse_arguments() -> argparse.Namespace:
"Script exits with error if losses differ."
),
)
parser.add_argument(
"--export-result",
default="",
help=(
"Export losses to specified file path (requires --assert-equal). "
"Exports only when losses match. Format: '{step} {loss}' per line."
),
)
parser.add_argument(
"--import-result",
default="",
help=(
"Import losses from specified file path for comparison "
"(requires --assert-equal). "
"Compares imported losses with both baseline and test "
"(all 3 must match)."
),
)
parser.add_argument(
"--job-dump-folder",
default="outputs",
@@ -787,6 +908,14 @@
if not args.output_folder:
args.output_folder = None

# Convert empty export_result to None
if not args.export_result:
args.export_result = None

# Convert empty import_result to None
if not args.import_result:
args.import_result = None

return args


@@ -850,6 +979,9 @@ def main() -> None:
args.test_train_file,
args.test_options,
args.steps,
args.assert_equal,
args.export_result,
args.import_result,
)

# Setup environment
@@ -912,7 +1044,14 @@

# Assert losses are equal if requested
if args.assert_equal:
assert_losses_equal(baseline_log, test_log)
# Pass import_result if provided for 3-way comparison
assert_losses_equal(baseline_log, test_log, args.import_result)

# Export losses if requested (only after assertion passes)
if args.export_result:
# Extract baseline losses (they equal test losses since assertion passed)
baseline_losses = extract_losses_from_log(baseline_log)
export_losses_to_file(baseline_losses, args.export_result)

# Analysis and reporting
perform_loss_analysis(baseline_log, test_log, stats_file)
10 changes: 10 additions & 0 deletions tests/assets/losses/llama3.txt
Contributor
maybe put in https://github.com/pytorch/torchtitan/tree/main/tests/assets and just call it llama3_losses.txt?

Contributor
n00b q: does this ground-truth loss come from a single-GPU run, or is it FSDP only?

@@ -0,0 +1,10 @@
1 8.1376
2 7.841
3 7.1815
4 6.3509
5 5.5272
6 4.9244
7 4.5606
8 4.3724
9 4.347
10 4.2004
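For reference, the ground-truth file above follows the same "{step} {loss}" convention the script parses; a quick sanity-check sketch with the file contents inlined so it is self-contained:

```python
# Parse the inlined llama3.txt ground-truth losses and sanity-check them.
content = """\
1 8.1376
2 7.841
3 7.1815
4 6.3509
5 5.5272
6 4.9244
7 4.5606
8 4.3724
9 4.347
10 4.2004
"""

losses = {}
for line in content.splitlines():
    step, loss = line.split()
    losses[int(step)] = float(loss)

assert len(losses) == 10
assert losses[1] == 8.1376
assert losses[10] == 4.2004
# Loss decreases from the first to the last recorded step.
assert losses[10] < losses[1]
print("parsed", len(losses), "steps")
```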