From dc6b81378e68debc16e7ccbfeb7f00b48e8ed4e6 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 24 Aug 2025 00:01:21 -0700 Subject: [PATCH 01/15] some initial code --- image_processing/Dockerfile | 10 +++ image_processing/README.md | 9 +++ image_processing/job.yaml | 50 +++++++++++++ image_processing/process_images.py | 112 +++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 image_processing/Dockerfile create mode 100644 image_processing/README.md create mode 100644 image_processing/job.yaml create mode 100644 image_processing/process_images.py diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile new file mode 100644 index 0000000..1ae25ae --- /dev/null +++ b/image_processing/Dockerfile @@ -0,0 +1,10 @@ +FROM anyscale/ray:2.48.0-slim-py312-cu128 + +# C compiler for Triton’s runtime build step (vLLM V1 engine) +# https://github.com/vllm-project/vllm/issues/2997 +RUN sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends build-essential + +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +RUN uv pip install --system huggingface_hub diff --git a/image_processing/README.md b/image_processing/README.md new file mode 100644 index 0000000..5ea8aca --- /dev/null +++ b/image_processing/README.md @@ -0,0 +1,9 @@ +# Process images + +This example uses Ray Data to process the [ReLAION-2B](https://huggingface.co/datasets/laion/relaion2B-en-research-safe) image dataset, which consists of over 2 billion rows. Each row consists of an image URL along with various metadata include a caption and image dimensions. + +## Install the Anyscale CLI + +``` +anyscale job submit -f job.yaml --env HF_TOKEN=$HF_TOKEN +``` \ No newline at end of file diff --git a/image_processing/job.yaml b/image_processing/job.yaml new file mode 100644 index 0000000..7663989 --- /dev/null +++ b/image_processing/job.yaml @@ -0,0 +1,50 @@ +# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. + +name: process-images + +# When empty, use the default image. This can be an Anyscale-provided base image +# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided +# that it meets certain specs), or you can build new images using the Anyscale +# image builder at https://console.anyscale-staging.com/v2/container-images. +# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 +containerfile: ./Dockerfile + +# When empty, Anyscale will auto-select the instance types. You can also specify +# minimum and maximum resources. +compute_config: +# head_node: +# instance_type: m5.2xlarge +# worker_nodes: +# - instance_type: m5.16xlarge +# min_nodes: 0 +# max_nodes: 100 +# - instance_type: m7a.24xlarge +# min_nodes: 0 +# max_nodes: 100 +# market_type: PREFER_SPOT # Defaults to ON_DEMAND +# - instance_type: g4dn.2xlarge +# min_nodes: 0 +# max_nodes: 100 +# market_type: PREFER_SPOT # Defaults to ON_DEMAND +# min_resources: +# CPU: 100 +# GPU: 1 +# max_resources: +# CPU: 5000 +# GPU: 100 + auto_select_worker_config: true + +# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that +# will be the working directory for the job. The files in the directory will be +# automatically uploaded to the job environment in Anyscale. +working_dir: . + +# When empty, this uses the default Anyscale Cloud in your organization. +cloud: + +# The script to run in your job. You can also do "uv run main.py" if you have a +# pyproject.toml file in your working_dir. 
+entrypoint: python process_images.py + +# If there is an error, do not retry. +max_retries: 0 \ No newline at end of file diff --git a/image_processing/process_images.py b/image_processing/process_images.py new file mode 100644 index 0000000..862311f --- /dev/null +++ b/image_processing/process_images.py @@ -0,0 +1,112 @@ +import concurrent.futures +import os +import ray +import requests +from huggingface_hub import HfFileSystem + + +def fetch_image_original(row): + try: + response = requests.get(row["url"], timeout=5) + if response.status_code == 200: + row["image_bytes"] = response.content + row["success"] = True + row["error"] = None + else: + row["image_bytes"] = None + row["success"] = False + row["error"] = f"Status code: {response.status_code}" + return row + except Exception as e: + row["image_bytes"] = None + row["success"] = False + row["error"] = str(e) + return row + + +def fetch_image(url): + try: + response = requests.get(url, timeout=5) + if response.status_code == 200: + image_bytes = response.content + success = True + error = None + else: + image_bytes = None + success = False + error = f"Status code: {response.status_code}" + except Exception as e: + image_bytes = None + success = False + error = str(e) + + return image_bytes, success, error + + +def fetch_images_batch_threaded(batch): + with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: + results = list(executor.map(fetch_image, batch["url"])) + batch["image_bytes"] = [result[0] for result in results] + batch["success"] = [result[1] for result in results] + batch["error"] = [result[2] for result in results] + return batch + + +dataset = ray.data.read_parquet( + "hf://datasets/laion/relaion2B-en-research-safe/", + file_extensions=["parquet"], + filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]), + ray_remote_args={"memory": 10*10**9} +) +dataset = dataset.repartition(target_num_rows_per_block=5000) +dataset = dataset.map_batches( + fetch_images_batch_threaded, + batch_size=1000, + # ray_remote_args_fn=lambda: { + # "memory": 2 * 10**9 + # }, +) +dataset = dataset.materialize() + +""" +urls = [ + 'https://lid.zoocdn.com/354/255/1c262cde0e91356edf2f1ddad3e92ae8ccf9dd98.jpg', + 'http://tse2.mm.bing.net/th?id=OIP.Gqs8tMg9NcakYZVveTRSfQEsEs', + 'https://shop.foleyfoodandwinesociety.com/assets/images/products/pictures/21021-560.png', + 'https://d1ea30dbll17d8.cloudfront.net/12/1/images/catalog/i/xl_137503-P7010009.jpg', + 'http://rlv.zcache.com/pomeranian_christmas_card_santa_and_bears-r500df5b7d0f94c258ca220c7110273bf_xvuak_8byvr_152.jpg', + 'https://reisetopia.de/wp-content/uploads/2018/04/Swiss-First-Class-Internet-Voucher-1024x576.jpg', + 'http://tse3.mm.bing.net/th?id=OIP.g1eGMQVWD1zBOqlSW6tG9wHaHf', + 'https://tap2.fkimg.com/media/vr-splice-j/01/41/0b/04.jpg', + 'https://i.pinimg.com/736x/3f/e5/32/3fe532c33b9babe6ed6df71a137891d0.jpg', + 'https://s3-us-west-2.amazonaws.com/tabs.web.media/c/7/c7vf/c7vf-square-175.jpg', + 'https://secure.img.wfrcdn.com/lf/43/hash/6189/7540185/1/26%2BBottle%2BSingle%2BZone%2BBuilt-In%2BWine%2BRefrigerator.jpg', + 'https://storage.googleapis.com/idx-photos-gs.ihouseprd.com/OR-RMLS/17075635/1x/000.jpg', + 'http://cdn-w.v12soft.com/photos/8PtebOg/9122296/fszBM_800600.jpg', + 'https://blogs.gnome.org/aklapper/files/2011/03/sponsored-gnome-badge-shadow.png', + 'https://laclothing.co.uk/wp-content/uploads/prod_dp37_88695-400x533.jpg', + 'https://onokovtsy.diamondelectric.ru/images/2957/2956669/small_kofemashina_philips_ep503510_lattego_series_5000_1.jpg', + 
'https://www.enigmasoftware.com/wp-content/themes/default/images/pages/download/spyhunter/chrome/2.jpg', + 'https://media.vanityfair.com/photos/54cab13f674871890b5748f9/master/w_320%2Cc_limit/image.jpg', + 'http://ih1.redbubble.net/image.16549013.9491/flat,220x200,075,t.jpg', + 'https://www.cstatic-images.com/stock/400x300/265593.jpg', + 'https://image.spreadshirtmedia.net/image-server/v1/products/T1437A737PA4399PT17X218Y32D174571818FS798/views/1,width=500,height=500,appearanceId=737/diagramme-t-shirt-manches-longues-henley.jpg', + 'https://www.picclickimg.com/d/l400/pict/371869372230_/South-Carolina-Camper.jpg', + 'https://www.amysbakehouse.com.au/wp-content/uploads/2015/08/30thBirthdayCake5-300x300.jpg', + 'https://images-eu.ssl-images-amazon.com/images/I/51uUj7YpvAL._AC_UL160_SR102,160_.jpg', + 'https://i.ytimg.com/vi/Bas_HPoslzY/mqdefault.jpg', + 'https://aviewfrommyseat.com/medium/anonymous-20150912164947.jpg', + 'https://auto.cdn-rivamedia.com/photos/annonce/bigcriteo/mercedes-classe-c-180-d-122ch-amg-line-9g-tronic-117174350.jpg', + 'https://ae01.alicdn.com/kf/HTB1BSAEjYSYBuNjSspiq6xNzpXaS/Custom-photo-3d-wallpaper-cloth-Motorcycle-retro-nostalgic-living-room-Home-decor-3d-wall-murals-wallpaper.jpg', + 'http://booklikes.com/photo/max/220/330/upload/books/b/5/b50d9f8dd976cb135363e03ee6f279fd.jpg', + 'https://images.snapwi.re/d571/5a675db017312ea0328b456f.w800.jpg', + 'https://m.smedata.sk/api-media/media/image/sme/0/43/4366330/4366330_600x400.jpeg?rev=3', + 'https://d31l02nbp0owar.cloudfront.net/m/t/15589/15579430/a-0005.jpg', + 'https://img.shopstyle-cdn.com/pim/b2/c7/b2c7ad1e66c19e0aef0fad8464580abd_xlarge.jpg', + 'https://www.fineartstorehouse.com/t/629/greater-yellowlegs-reflects-13091205.jpg.webp', + 'https://media.ticmate.com/resources/ticmate_live/upload/leeds_united_team_logo.jpg', + 'http://d3d71ba2asa5oz.cloudfront.net/62000804/images/by0006_2.jpg', + 'https://fscomps.fotosearch.com/compc/CSP/CSP990/seamless-light-oak-square-parquet-panel-clipart__k10885495.jpg', + 'https://i.pinimg.com/236x/29/86/6b/29866b98be3977e1f8f7c58c27a1607d.jpg' +] +""" \ No newline at end of file From b4971057606dfbacffd359fe6845d21f8105a76d Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 24 Aug 2025 23:43:02 -0700 Subject: [PATCH 02/15] Add vllm --- image_processing/Dockerfile | 4 + image_processing/process_images.py | 188 +++++++++++++++++------------ 2 files changed, 115 insertions(+), 77 deletions(-) diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile index 1ae25ae..01eff96 100644 --- a/image_processing/Dockerfile +++ b/image_processing/Dockerfile @@ -8,3 +8,7 @@ RUN sudo apt-get update && \ RUN curl -LsSf https://astral.sh/uv/install.sh | sh RUN uv pip install --system huggingface_hub + +RUN uv pip install --system vllm==0.9.2 +# Avoid https://github.com/vllm-project/vllm-ascend/issues/2046 with transformers < 4.54.0 +RUN uv pip install --system transformers==4.53.3 \ No newline at end of file diff --git a/image_processing/process_images.py b/image_processing/process_images.py index 862311f..a6987ce 100644 --- a/image_processing/process_images.py +++ b/image_processing/process_images.py @@ -3,34 +3,32 @@ import ray import requests from huggingface_hub import HfFileSystem +from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor +from PIL import Image +from io import BytesIO -def fetch_image_original(row): - try: - response = requests.get(row["url"], timeout=5) - if response.status_code == 200: - row["image_bytes"] = response.content - 
row["success"] = True - row["error"] = None - else: - row["image_bytes"] = None - row["success"] = False - row["error"] = f"Status code: {response.status_code}" - return row - except Exception as e: - row["image_bytes"] = None - row["success"] = False - row["error"] = str(e) - return row +num_images = 1000000 +num_model_replicas = 64 +tensor_parallelism = 4 + +output_path = os.path.join(os.environ["ANYSCALE_ARTIFACT_STORAGE"], "rkn/process_images_output") def fetch_image(url): try: response = requests.get(url, timeout=5) if response.status_code == 200: - image_bytes = response.content - success = True - error = None + ctype = response.headers.get("Content-Type", "") + if ctype.startswith("image"): + image_bytes = response.content + Image.open(BytesIO(image_bytes)) # Validate the image formatting. + success = True + error = None + else: + image_bytes = None + success = False + error = f"Content-Type is not an image: {ctype}" else: image_bytes = None success = False @@ -43,8 +41,14 @@ def fetch_image(url): return image_bytes, success, error +def convert_to_pil_image(row): + row["pil_image"] = Image.open(BytesIO(row["image_bytes"])) + return row + + def fetch_images_batch_threaded(batch): - with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: + # Previously used 50 instead of 250 + with concurrent.futures.ThreadPoolExecutor(max_workers=250) as executor: results = list(executor.map(fetch_image, batch["url"])) batch["image_bytes"] = [result[0] for result in results] batch["success"] = [result[1] for result in results] @@ -52,61 +56,91 @@ def fetch_images_batch_threaded(batch): return batch -dataset = ray.data.read_parquet( - "hf://datasets/laion/relaion2B-en-research-safe/", - file_extensions=["parquet"], - filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]), - ray_remote_args={"memory": 10*10**9} +vision_processor_config = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + tensor_parallel_size=tensor_parallelism, + pipeline_parallel_size=1, + max_model_len=32768, + enable_chunked_prefill=True, + max_num_batched_tokens=2048, + ), + # Override Ray's runtime env to include the Hugging Face token. Ray Data uses Ray under the hood to orchestrate the inference pipeline. + runtime_env=dict( + env_vars=dict( + VLLM_USE_V1="1", + ), + ), + batch_size=16, + accelerator_type="A10G", + concurrency=num_model_replicas, + has_image=True, ) -dataset = dataset.repartition(target_num_rows_per_block=5000) -dataset = dataset.map_batches( - fetch_images_batch_threaded, - batch_size=1000, - # ray_remote_args_fn=lambda: { - # "memory": 2 * 10**9 - # }, + + +def vision_preprocess(row: dict) -> dict: + return dict( + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + # Ray Data accepts PIL Image or image URL. 
+ # "image": row["pil_image"], + "image": Image.open(BytesIO(row["image_bytes"])) + # "image": row["image_bytes"], + }, + ] + }, + ], + sampling_params=dict( + temperature=0.3, + max_tokens=150, + detokenize=False, + ), + ) + + +def vision_postprocess(row: dict) -> dict: + return row + + +vision_processor = build_llm_processor( + vision_processor_config, + preprocess=vision_preprocess, + postprocess=vision_postprocess, ) -dataset = dataset.materialize() - -""" -urls = [ - 'https://lid.zoocdn.com/354/255/1c262cde0e91356edf2f1ddad3e92ae8ccf9dd98.jpg', - 'http://tse2.mm.bing.net/th?id=OIP.Gqs8tMg9NcakYZVveTRSfQEsEs', - 'https://shop.foleyfoodandwinesociety.com/assets/images/products/pictures/21021-560.png', - 'https://d1ea30dbll17d8.cloudfront.net/12/1/images/catalog/i/xl_137503-P7010009.jpg', - 'http://rlv.zcache.com/pomeranian_christmas_card_santa_and_bears-r500df5b7d0f94c258ca220c7110273bf_xvuak_8byvr_152.jpg', - 'https://reisetopia.de/wp-content/uploads/2018/04/Swiss-First-Class-Internet-Voucher-1024x576.jpg', - 'http://tse3.mm.bing.net/th?id=OIP.g1eGMQVWD1zBOqlSW6tG9wHaHf', - 'https://tap2.fkimg.com/media/vr-splice-j/01/41/0b/04.jpg', - 'https://i.pinimg.com/736x/3f/e5/32/3fe532c33b9babe6ed6df71a137891d0.jpg', - 'https://s3-us-west-2.amazonaws.com/tabs.web.media/c/7/c7vf/c7vf-square-175.jpg', - 'https://secure.img.wfrcdn.com/lf/43/hash/6189/7540185/1/26%2BBottle%2BSingle%2BZone%2BBuilt-In%2BWine%2BRefrigerator.jpg', - 'https://storage.googleapis.com/idx-photos-gs.ihouseprd.com/OR-RMLS/17075635/1x/000.jpg', - 'http://cdn-w.v12soft.com/photos/8PtebOg/9122296/fszBM_800600.jpg', - 'https://blogs.gnome.org/aklapper/files/2011/03/sponsored-gnome-badge-shadow.png', - 'https://laclothing.co.uk/wp-content/uploads/prod_dp37_88695-400x533.jpg', - 'https://onokovtsy.diamondelectric.ru/images/2957/2956669/small_kofemashina_philips_ep503510_lattego_series_5000_1.jpg', - 'https://www.enigmasoftware.com/wp-content/themes/default/images/pages/download/spyhunter/chrome/2.jpg', - 'https://media.vanityfair.com/photos/54cab13f674871890b5748f9/master/w_320%2Cc_limit/image.jpg', - 'http://ih1.redbubble.net/image.16549013.9491/flat,220x200,075,t.jpg', - 'https://www.cstatic-images.com/stock/400x300/265593.jpg', - 'https://image.spreadshirtmedia.net/image-server/v1/products/T1437A737PA4399PT17X218Y32D174571818FS798/views/1,width=500,height=500,appearanceId=737/diagramme-t-shirt-manches-longues-henley.jpg', - 'https://www.picclickimg.com/d/l400/pict/371869372230_/South-Carolina-Camper.jpg', - 'https://www.amysbakehouse.com.au/wp-content/uploads/2015/08/30thBirthdayCake5-300x300.jpg', - 'https://images-eu.ssl-images-amazon.com/images/I/51uUj7YpvAL._AC_UL160_SR102,160_.jpg', - 'https://i.ytimg.com/vi/Bas_HPoslzY/mqdefault.jpg', - 'https://aviewfrommyseat.com/medium/anonymous-20150912164947.jpg', - 'https://auto.cdn-rivamedia.com/photos/annonce/bigcriteo/mercedes-classe-c-180-d-122ch-amg-line-9g-tronic-117174350.jpg', - 'https://ae01.alicdn.com/kf/HTB1BSAEjYSYBuNjSspiq6xNzpXaS/Custom-photo-3d-wallpaper-cloth-Motorcycle-retro-nostalgic-living-room-Home-decor-3d-wall-murals-wallpaper.jpg', - 'http://booklikes.com/photo/max/220/330/upload/books/b/5/b50d9f8dd976cb135363e03ee6f279fd.jpg', - 'https://images.snapwi.re/d571/5a675db017312ea0328b456f.w800.jpg', - 'https://m.smedata.sk/api-media/media/image/sme/0/43/4366330/4366330_600x400.jpeg?rev=3', - 'https://d31l02nbp0owar.cloudfront.net/m/t/15589/15579430/a-0005.jpg', - 'https://img.shopstyle-cdn.com/pim/b2/c7/b2c7ad1e66c19e0aef0fad8464580abd_xlarge.jpg', - 
'https://www.fineartstorehouse.com/t/629/greater-yellowlegs-reflects-13091205.jpg.webp', - 'https://media.ticmate.com/resources/ticmate_live/upload/leeds_united_team_logo.jpg', - 'http://d3d71ba2asa5oz.cloudfront.net/62000804/images/by0006_2.jpg', - 'https://fscomps.fotosearch.com/compc/CSP/CSP990/seamless-light-oak-square-parquet-panel-clipart__k10885495.jpg', - 'https://i.pinimg.com/236x/29/86/6b/29866b98be3977e1f8f7c58c27a1607d.jpg' -] -""" \ No newline at end of file + + +ray.data.DataContext.get_current().retried_io_errors.extend( + [ + "Temporary failure in name resolution", + "Max retries exceeded with url", + "Failed to establish a new connection", + "HTTPSConnectionPool", + ] +) + + +dataset = ray.data \ + .read_parquet( + "hf://datasets/laion/relaion2B-en-research-safe/", + file_extensions=["parquet"], + filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]), + ray_remote_args={"memory": 10*10**9}, + ) \ + .limit(num_images) \ + .repartition(target_num_rows_per_block=5000) \ + .map_batches( + fetch_images_batch_threaded, + batch_size=1000, + # ray_remote_args_fn=lambda: { + # "memory": 2 * 10**9 + # }, + ) \ + .filter(lambda row: row["success"]) \ + .filter(lambda row: Image.open(BytesIO(row["image_bytes"])).format == "JPEG") + +dataset = vision_processor(dataset) +dataset.write_parquet(output_path) \ No newline at end of file From 2e7afe82a7e1df51ef6df186fec210dc6cbddd22 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Mon, 3 Nov 2025 11:27:12 -0800 Subject: [PATCH 03/15] Scale image processing pipeline for 2B+ images - Update Ray base image to 2.51.1 and vLLM to 0.11.0 - Add boto3 dependency for S3 operations - Update transformers to 4.57.1 for compatibility - Configure compute resources with auto-selection (max 520 CPU, 128 GPU) - Add disk size configuration options for customer-hosted deployments - Implement robust URL validation and error handling - Add base64 image encoding for Arrow serialization - Add JPEG format validation and 128x128 image resizing - Scale model replicas from 1 to 32 for higher throughput - Optimize batch sizes and memory usage for large-scale processing - Implement session pooling for HTTP requests with retry logic - Add timestamp-based output paths to /mnt/shared_storage - Add run.sh script for job submission with HF_TOKEN --- image_processing/Dockerfile | 10 +- image_processing/job.yaml | 52 +++-- image_processing/process_images.py | 313 ++++++++++++++++++++++++----- image_processing/run.sh | 2 + 4 files changed, 303 insertions(+), 74 deletions(-) create mode 100755 image_processing/run.sh diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile index 01eff96..d48065a 100644 --- a/image_processing/Dockerfile +++ b/image_processing/Dockerfile @@ -1,4 +1,4 @@ -FROM anyscale/ray:2.48.0-slim-py312-cu128 +FROM anyscale/ray:2.51.1-slim-py312-cu128 # C compiler for Triton’s runtime build step (vLLM V1 engine) # https://github.com/vllm-project/vllm/issues/2997 @@ -7,8 +7,8 @@ RUN sudo apt-get update && \ RUN curl -LsSf https://astral.sh/uv/install.sh | sh -RUN uv pip install --system huggingface_hub +RUN uv pip install --system huggingface_hub boto3 -RUN uv pip install --system vllm==0.9.2 -# Avoid https://github.com/vllm-project/vllm-ascend/issues/2046 with transformers < 4.54.0 -RUN uv pip install --system transformers==4.53.3 \ No newline at end of file +RUN uv pip install --system vllm==0.11.0 + +RUN uv pip install --system transformers==4.57.1 \ No newline at end of file diff --git a/image_processing/job.yaml b/image_processing/job.yaml index 
7663989..cdd32c0 100644 --- a/image_processing/job.yaml +++ b/image_processing/job.yaml @@ -12,28 +12,40 @@ containerfile: ./Dockerfile # When empty, Anyscale will auto-select the instance types. You can also specify # minimum and maximum resources. compute_config: -# head_node: -# instance_type: m5.2xlarge -# worker_nodes: -# - instance_type: m5.16xlarge -# min_nodes: 0 -# max_nodes: 100 -# - instance_type: m7a.24xlarge -# min_nodes: 0 -# max_nodes: 100 -# market_type: PREFER_SPOT # Defaults to ON_DEMAND -# - instance_type: g4dn.2xlarge -# min_nodes: 0 -# max_nodes: 100 -# market_type: PREFER_SPOT # Defaults to ON_DEMAND -# min_resources: -# CPU: 100 -# GPU: 1 -# max_resources: -# CPU: 5000 -# GPU: 100 + # OPTION 1: Auto-selection (current - works on Anyscale-hosted) + # Uses default disk sizes (~100GB). Cannot customize disk with auto-selection. + min_resources: + CPU: 0 + GPU: 0 + max_resources: + CPU: 520 + GPU: 128 auto_select_worker_config: true + # OPTION 2: Explicit config with custom disk (CUSTOMER-HOSTED ONLY) + # Uncomment below and comment out the auto-selection config above to use custom disk. + # NOTE: advanced_instance_config only works on customer-hosted AWS accounts. + # See DISK_SIZE_OPTIONS.md for details. + # + # head_node: + # instance_type: m5.2xlarge + # advanced_instance_config: + # BlockDeviceMappings: + # - DeviceName: /dev/sda1 + # Ebs: + # VolumeSize: 500 + # VolumeType: gp3 + # worker_nodes: + # - instance_type: m5.16xlarge + # min_nodes: 0 + # max_nodes: 100 + # advanced_instance_config: + # BlockDeviceMappings: + # - DeviceName: /dev/sda1 + # Ebs: + # VolumeSize: 500 + # VolumeType: gp3 + # Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that # will be the working directory for the job. The files in the directory will be # automatically uploaded to the job environment in Anyscale. diff --git a/image_processing/process_images.py b/image_processing/process_images.py index 3bb2ba8..00d02d3 100644 --- a/image_processing/process_images.py +++ b/image_processing/process_images.py @@ -1,3 +1,4 @@ +import base64 import concurrent.futures import os import ray @@ -6,52 +7,220 @@ from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor from PIL import Image from io import BytesIO +import pyarrow.fs as pafs +from requests.adapters import HTTPAdapter +import urllib3 +import logging +import warnings +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) -num_images = 100 -num_model_replicas = 1 + +# Disable SSL warnings since we're disabling verification for misconfigured image hosts +# Suppress urllib3 connection pool warnings (timeout, connection errors, etc.) +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +# ============================================================================ +# SCALABILITY CONFIGURATION FOR 2B+ IMAGES +# ============================================================================ +# num_images = 100 +num_model_replicas = 32 tensor_parallelism = 1 +max_concurrent_downloads = 10 # Reduced to minimize memory spikes (was 10) + +from datetime import datetime, timezone + +timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + +output_path = f"/mnt/shared_storage/process_images_output/{timestamp}" + + +def create_session(): + """ + Create a requests session for image downloads without automatic retries. 
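+
+    A single pooled session is shared by the download threads within a batch so
+    that HTTP connections are reused instead of being re-opened for every URL.
+    Failed downloads are recorded via the per-row `success` flag and filtered out
+    downstream instead of being retried per request.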
+ """ + session = requests.Session() + + adapter = HTTPAdapter( + pool_connections=50, + pool_maxsize=50, + # Keep connections alive longer + pool_block=False, + ) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + -output_path = os.path.join(os.environ["ANYSCALE_ARTIFACT_STORAGE"], "rkn/process_images_output") +def fetch_image(url, session=None): + """ + Fetch image with validation and error handling. + If the download or validation fails, the image is marked as invalid without retrying. + """ + # Validate URL format first + if not url or not isinstance(url, str): + return None, False, "Invalid URL: empty or not a string" + + # Parse URL to properly validate its structure + from urllib.parse import urlparse -def fetch_image(url): try: - response = requests.get(url, timeout=5) - if response.status_code == 200: - ctype = response.headers.get("Content-Type", "") - if ctype.startswith("image"): - image_bytes = response.content - Image.open(BytesIO(image_bytes)) # Validate the image formatting. - success = True - error = None - else: - image_bytes = None - success = False - error = f"Content-Type is not an image: {ctype}" - else: - image_bytes = None - success = False - error = f"Status code: {response.status_code}" + parsed = urlparse(url) + + # Check if URL has a valid scheme (http or https) + if parsed.scheme not in ("http", "https"): + return ( + None, + False, + f"Invalid URL", + ) + + # Check if URL has a valid hostname/netloc + # netloc is the network location (domain/IP), e.g., "example.com" or "192.168.1.1" + if not parsed.netloc or len(parsed.netloc.strip()) < 3: + return None, False, f"Invalid URL: missing or invalid hostname: {url[:100]}" + + # Check if netloc contains at least a dot (for domain) or is localhost/IP + # This catches URLs like "https://albumart/image.jpg" which have no valid domain + if "." 
not in parsed.netloc and parsed.netloc not in ("localhost", "127.0.0.1"): + # Allow if it looks like an IPv6 address (contains colons) + if ":" not in parsed.netloc: + return ( + None, + False, + f"Invalid URL: malformed hostname (missing domain): {url[:100]}", + ) + except Exception as e: - image_bytes = None - success = False - error = str(e) + return None, False, f"Invalid URL: failed to parse: {str(e)[:100]}" - return image_bytes, success, error + # Create session if not provided (will be reused within batch) + if session is None: + session = create_session() + try: + response = session.get( + url, + timeout=(10, 20), # (connect_timeout=30s, read_timeout=60s) + verify=False, # Disable SSL verification for broken certs + allow_redirects=True, # Follow redirects + stream=False, # Download entire response + ) + except Exception as e: + return None, False, f"Error: {str(e)[:10]}" -def convert_to_pil_image(row): - row["pil_image"] = Image.open(BytesIO(row["image_bytes"])) - return row + if response.status_code == 200: + ctype = response.headers.get("Content-Type", "") + if ctype.startswith("image"): + image_bytes = response.content + try: + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=UserWarning) + warnings.filterwarnings( + "error", category=Image.DecompressionBombWarning + ) + # First verify the image format + img = Image.open(BytesIO(image_bytes)) + img.verify() # This checks if file is broken + # After verify(), we need to reopen to actually load the image + img = Image.open(BytesIO(image_bytes)) + img.load() # Force full image loading to detect truncation + img.close() + except (OSError, IOError) as e: + # Catch truncated images and other IO errors + error_msg = str(e)[:100] + if "truncated" in error_msg.lower(): + return None, False, f"Truncated image: {error_msg}" + return None, False, f"Image IO error: {error_msg}" + except Exception as e: + return None, False, f"Image validation error: {str(e)[:100]}" + return image_bytes, True, None + return None, False, f"Content-Type is not an image: {ctype}" + + return None, False, f"Status code: {response.status_code}" + + +def is_jpeg_format(row): + """Memory-efficient JPEG format check without keeping Image object in memory.""" + try: + image_data = row.get("image_base64") + if image_data is None: + return False + if isinstance(image_data, str): + image_data = base64.b64decode(image_data) + with Image.open(BytesIO(image_data)) as img: + return img.format == "JPEG" + except: + return False + + +def resize_image(row): + """Resize image to 128x128 pixels and standardize RGB values.""" + try: + image_data = row.get("image_base64") + if image_data is None: + return row + + # Decode base64 string to bytes + if isinstance(image_data, str): + image_bytes = base64.b64decode(image_data) + else: + image_bytes = image_data + + # Open image, convert to RGB, resize, and save back + with Image.open(BytesIO(image_bytes)) as img: + # Convert to RGB mode to ensure consistent 3-channel format + # This handles CMYK, grayscale, RGBA, etc. 
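+            # JPEG output below cannot store an alpha channel, so any RGBA/LA
+            # data is dropped by the RGB conversion here.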
+ if img.mode != "RGB": + img = img.convert("RGB") + + # Resize to 128x128 using high-quality Lanczos resampling + resized_img = img.resize((128, 128), Image.Resampling.LANCZOS) + + # Save resized image to bytes + output_buffer = BytesIO() + resized_img.save(output_buffer, format="JPEG", quality=95) + resized_bytes = output_buffer.getvalue() + + # Encode back to base64 string + row["image_base64"] = base64.b64encode(resized_bytes).decode("ascii") + return row + except Exception as e: + # If resize fails, keep original image + return row def fetch_images_batch_threaded(batch): - with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: - results = list(executor.map(fetch_image, batch["url"])) - batch["image_bytes"] = [result[0] for result in results] + """Fetch images in parallel with increased concurrency for network throughput.""" + # Disable SSL warnings in each Ray worker process + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) + + # Create a single shared session across threads in this batch + # Note: requests.Session is thread-safe for reading + session = create_session() + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_concurrent_downloads + ) as executor: + # Pass session to each fetch_image call + results = list( + executor.map(lambda url: fetch_image(url, session), batch["url"]) + ) + + batch["image_base64"] = [ + ( + base64.b64encode(result[0]).decode("ascii") + if (result[0] is not None and result[1]) + else None + ) + for result in results + ] batch["success"] = [result[1] for result in results] - batch["error"] = [result[2] for result in results] return batch @@ -71,7 +240,8 @@ def fetch_images_batch_threaded(batch): VLLM_DISABLE_COMPILE_CACHE="1", ), ), - batch_size=16, + batch_size=8, # Reduced from 16 to lower memory usage + max_concurrent_batches=16, # Increased to saturate vLLM engine (8 * 16 = 128) accelerator_type="A10G", concurrency=num_model_replicas, has_image=True, @@ -79,6 +249,11 @@ def fetch_images_batch_threaded(batch): def vision_preprocess(row: dict) -> dict: + # Keep image data as base64 string for Arrow serialization + # The vLLM engine will handle the conversion internally + image_data = row["image_base64"] + + image_data = f"data:image;base64,{image_data}" return dict( messages=[ { @@ -86,12 +261,9 @@ def vision_preprocess(row: dict) -> dict: "content": [ { "type": "image", - # Ray Data accepts PIL Image or image URL. 
- # "image": row["pil_image"], - "image": Image.open(BytesIO(row["image_bytes"])) - # "image": row["image_bytes"], + "image": image_data, }, - ] + ], }, ], sampling_params=dict( @@ -103,6 +275,7 @@ def vision_preprocess(row: dict) -> dict: def vision_postprocess(row: dict) -> dict: + row.pop("image_base64") return row @@ -113,33 +286,75 @@ def vision_postprocess(row: dict) -> dict: ) +# Initialize Ray with S3 spilling configuration +# This ensures Ray can access S3 for object spilling on all workers +if not ray.is_initialized(): + ray.init() + ray.data.DataContext.get_current().retried_io_errors.extend( [ + # Network connectivity errors "Temporary failure in name resolution", + "Name or service not known", "Max retries exceeded with url", "Failed to establish a new connection", + "Connection refused", + "Connection timed out", + "Read timed out", + "ConnectTimeoutError", + "connect timeout", "HTTPSConnectionPool", + "Remote end closed connection", + "Connection broken", + # SSL/TLS errors + "SSLError", + "SSL: CERTIFICATE_VERIFY_FAILED", + "hostname mismatch", + "certificate verify failed", + # Rate limiting + "429 Client Error: Too Many Requests", + "We had to rate limit you", ] ) +num_cpu = 512 +tasks_per_cpu = 1 +concurrency = num_cpu * tasks_per_cpu +ctx = ray.data.DataContext.get_current() +target_block_size_mb = 128 +ctx.target_max_block_size = target_block_size_mb * 1024 * 1024 +ctx.use_push_based_shuffle = False -dataset = ray.data \ - .read_parquet( +# Data pipeline with scalability optimizations +dataset = ( + ray.data.read_parquet( "hf://datasets/laion/relaion2B-en-research-safe/", file_extensions=["parquet"], + columns=["url"], filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]), - ray_remote_args={"memory": 10*10**9}, - ) \ - .limit(num_images) \ - .repartition(target_num_rows_per_block=1000) \ + concurrency=concurrency, + num_cpus=2, + memory=int(4 * 1024**3), + ) .map_batches( fetch_images_batch_threaded, - batch_size=200, - ) \ - .filter(lambda row: row["success"]) \ - .filter(lambda row: Image.open(BytesIO(row["image_bytes"])).format == "JPEG") + batch_size=50, + memory=int(2 * 1024**3), + num_cpus=2, + ) # removed partition to reduce memory usage and redundency + .filter(lambda row: row["success"]) + .filter(is_jpeg_format) + .map(resize_image) + .drop_columns(["success"]) +) # Drop success column early to reduce memory + +# Apply vision processing with scaled replicas +# Note: image_base64 column is dropped in vision_postprocess to avoid Arrow serialization issues dataset = vision_processor(dataset) -dataset = dataset.drop_columns(["image_bytes"]) -dataset.write_parquet(output_path) +# Write with optimizations for throughput and fault tolerance +dataset.write_parquet( + output_path, + max_rows_per_file=100000, # ~100K rows per file for manageable file sizes +) diff --git a/image_processing/run.sh b/image_processing/run.sh new file mode 100755 index 0000000..8758376 --- /dev/null +++ b/image_processing/run.sh @@ -0,0 +1,2 @@ +anyscale job submit -f job.yaml \ + --env HF_TOKEN=$HF_TOKEN \ No newline at end of file From 812c8766bb5f66f3af19b261ef2b6c20c9732071 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 12 Sep 2025 16:57:19 -0400 Subject: [PATCH 04/15] updates --- image_processing/process_images.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/image_processing/process_images.py b/image_processing/process_images.py index a6987ce..3bb2ba8 100644 --- a/image_processing/process_images.py +++ 
b/image_processing/process_images.py @@ -8,9 +8,9 @@ from io import BytesIO -num_images = 1000000 -num_model_replicas = 64 -tensor_parallelism = 4 +num_images = 100 +num_model_replicas = 1 +tensor_parallelism = 1 output_path = os.path.join(os.environ["ANYSCALE_ARTIFACT_STORAGE"], "rkn/process_images_output") @@ -47,8 +47,7 @@ def convert_to_pil_image(row): def fetch_images_batch_threaded(batch): - # Previously used 50 instead of 250 - with concurrent.futures.ThreadPoolExecutor(max_workers=250) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: results = list(executor.map(fetch_image, batch["url"])) batch["image_bytes"] = [result[0] for result in results] batch["success"] = [result[1] for result in results] @@ -69,6 +68,7 @@ def fetch_images_batch_threaded(batch): runtime_env=dict( env_vars=dict( VLLM_USE_V1="1", + VLLM_DISABLE_COMPILE_CACHE="1", ), ), batch_size=16, @@ -131,16 +131,15 @@ def vision_postprocess(row: dict) -> dict: ray_remote_args={"memory": 10*10**9}, ) \ .limit(num_images) \ - .repartition(target_num_rows_per_block=5000) \ + .repartition(target_num_rows_per_block=1000) \ .map_batches( fetch_images_batch_threaded, - batch_size=1000, - # ray_remote_args_fn=lambda: { - # "memory": 2 * 10**9 - # }, + batch_size=200, ) \ .filter(lambda row: row["success"]) \ .filter(lambda row: Image.open(BytesIO(row["image_bytes"])).format == "JPEG") dataset = vision_processor(dataset) -dataset.write_parquet(output_path) \ No newline at end of file +dataset = dataset.drop_columns(["image_bytes"]) +dataset.write_parquet(output_path) + From 75c025a86b7cf7930c07981a1245d12699e37dc7 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Tue, 18 Nov 2025 19:17:04 -0800 Subject: [PATCH 05/15] Add megatron_ray_fault_tolerant example with comprehensive fault tolerance - Implements PPO-style training with Megatron and Ray - Features automatic actor recovery from failures - Includes backup actor pool for seamless replacement - Supports DP, TP, PP, and CP parallelism - Distributed checkpoint saving/loading - Process group re-initialization after failures - Added comprehensive documentation in README files --- README.md | 102 ++ megatron_ray_fault_tolerant/.gitignore | 1 + .../.pre-commit-config.yaml | 20 + megatron_ray_fault_tolerant/Dockerfile | 34 + megatron_ray_fault_tolerant/README.md | 191 ++++ megatron_ray_fault_tolerant/dispatch.py | 299 ++++++ megatron_ray_fault_tolerant/file_io.py | 321 ++++++ megatron_ray_fault_tolerant/job.yaml | 45 + megatron_ray_fault_tolerant/main.py | 190 ++++ megatron_ray_fault_tolerant/megatron_actor.py | 934 ++++++++++++++++++ .../megatron_model_utils.py | 442 +++++++++ .../megatron_model_wrapper.py | 171 ++++ megatron_ray_fault_tolerant/megatron_utils.py | 465 +++++++++ megatron_ray_fault_tolerant/optimizer.py | 103 ++ megatron_ray_fault_tolerant/pyproject.toml | 98 ++ megatron_ray_fault_tolerant/run.sh | 1 + megatron_ray_fault_tolerant/training_batch.py | 371 +++++++ megatron_ray_fault_tolerant/utils.py | 286 ++++++ 18 files changed, 4074 insertions(+) create mode 100644 README.md create mode 100644 megatron_ray_fault_tolerant/.gitignore create mode 100644 megatron_ray_fault_tolerant/.pre-commit-config.yaml create mode 100644 megatron_ray_fault_tolerant/Dockerfile create mode 100644 megatron_ray_fault_tolerant/README.md create mode 100644 megatron_ray_fault_tolerant/dispatch.py create mode 100644 megatron_ray_fault_tolerant/file_io.py create mode 100644 megatron_ray_fault_tolerant/job.yaml create mode 100644 megatron_ray_fault_tolerant/main.py 
create mode 100644 megatron_ray_fault_tolerant/megatron_actor.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_utils.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_wrapper.py create mode 100644 megatron_ray_fault_tolerant/megatron_utils.py create mode 100644 megatron_ray_fault_tolerant/optimizer.py create mode 100644 megatron_ray_fault_tolerant/pyproject.toml create mode 100755 megatron_ray_fault_tolerant/run.sh create mode 100644 megatron_ray_fault_tolerant/training_batch.py create mode 100644 megatron_ray_fault_tolerant/utils.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..3a7e424 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# Examples + +This repository contains examples for deploying and running distributed applications. + +## Job Examples + +### 1. Hello World Job +**Directory:** `01_job_hello_world/` + +A simple "Hello World" example demonstrating how to submit and run basic jobs. + +### 2. Image Processing +**Directory:** `image_processing/` + +Process large-scale image datasets using Ray Data. This example demonstrates processing the ReLAION-2B dataset with over 2 billion rows. + +### 3. Megatron + Ray Fault Tolerant Training +**Directory:** `megatron_ray_fault_tolerant/` + +Implements PPO-style distributed training with Megatron and Ray, featuring comprehensive fault tolerance capabilities: +- Automatic actor recovery from failures +- Backup actor groups for seamless replacement +- Distributed checkpoint saving/loading +- Process group re-initialization after failures +- Support for tensor, pipeline, data, and context parallelism + +## Service Examples + +### 1. Hello World Service +**Directory:** `02_service_hello_world/` + +A simple service deployment example demonstrating the basics of Ray Serve. + +### 2. Deploy Llama 3.1 8B +**Directory:** `03_deploy_llama_3_8b/` + +Deploy Llama 3.1 8B model using Ray Serve and vLLM with autoscaling capabilities. + +### 3. Deploy Llama 3.1 70B +**Directory:** `deploy_llama_3_1_70b/` + +Deploy the larger Llama 3.1 70B model with optimized serving configuration. + +### 4. Tensor Parallel Serving +**Directory:** `serve_tensor_parallel/` + +Demonstrates tensor parallelism for serving large language models across multiple GPUs. + +### 5. FastVideo Generation +**Directory:** `video_generation_with_fastvideo/` + +Deploy a video generation service using the FastVideo framework. + +## Reinforcement Learning Examples + +### SkyRL +**Directory:** `skyrl/` + +Reinforcement learning training example using Ray and distributed computing. + +## Getting Started + +Most examples include their own README with specific instructions. Generally, you'll need: + +1. Install the Anyscale CLI: +```bash +pip install -U anyscale +anyscale login +``` + +2. Navigate to the example directory: +```bash +cd +``` + +3. Deploy the service or submit the job: +```bash +# For services +anyscale service deploy -f service.yaml + +# For jobs +anyscale job submit -f job.yaml +``` + +## Requirements + +- Anyscale account and CLI access +- Appropriate cloud credentials configured +- GPU resources for ML/LLM examples + +## Contributing + +When adding new examples: +1. Create a descriptive directory name +2. Include a README.md with setup and usage instructions +3. Add appropriate YAML configuration files +4. Update this main README with your example + +## License + +See individual example directories for specific licensing information. 
+ diff --git a/megatron_ray_fault_tolerant/.gitignore b/megatron_ray_fault_tolerant/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/megatron_ray_fault_tolerant/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/.pre-commit-config.yaml b/megatron_ray_fault_tolerant/.pre-commit-config.yaml new file mode 100644 index 0000000..5d51437 --- /dev/null +++ b/megatron_ray_fault_tolerant/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.9 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + exclude: (^(skyagent)/.*)$ + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + exclude: (^(skyagent)/.*)$ + + # Detect secrets and sensitive information + - repo: https://github.com/gitleaks/gitleaks + rev: v8.24.2 + hooks: + - id: gitleaks \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/Dockerfile b/megatron_ray_fault_tolerant/Dockerfile new file mode 100644 index 0000000..787c1c7 --- /dev/null +++ b/megatron_ray_fault_tolerant/Dockerfile @@ -0,0 +1,34 @@ +FROM anyscale/ray:2.51.0-slim-py312-cu128 + +RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev + +# the cuda compiler here is needed for deepspeed +RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \ + && sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run + +RUN curl -LsSf https://astral.sh/uv/0.9.4/install.sh | sh +RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc + + +RUN sudo apt-get update \ + && sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \ + libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \ + && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \ + && sudo apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +# ---------- PyTorch + cuDNN + Transformer Engine ---------- +# PyTorch + cuDNN + Transformer Engine +RUN pip install --no-cache-dir "torch==2.7.1" "nvidia-cudnn-cu12>=9.3" && \ + CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" && \ + sudo mkdir -p /opt && sudo ln -sfn "$CUDNN_PATH" /opt/cudnn && \ + echo "/opt/cudnn/lib" | sudo tee /etc/ld.so.conf.d/cudnn.conf >/dev/null && sudo ldconfig + +ENV CUDNN_PATH=/opt/cudnn +ENV CPATH=${CUDNN_PATH}/include:${CPATH} +ENV LD_LIBRARY_PATH=${CUDNN_PATH}/lib:${LD_LIBRARY_PATH} + +RUN pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.5.0" +# -------------------- diff --git a/megatron_ray_fault_tolerant/README.md b/megatron_ray_fault_tolerant/README.md new file mode 100644 index 0000000..b7abecc --- /dev/null +++ b/megatron_ray_fault_tolerant/README.md @@ -0,0 +1,191 @@ +# Megatron + Ray Fault Tolerant Training + +This example implements PPO-style distributed training using Megatron and Ray with comprehensive fault tolerance capabilities. The system can automatically recover from actor failures during training by utilizing backup actors and re-initializing process groups. + +## Key Features + +### Fault Tolerance Mechanisms + +1. 
**Actor Health Monitoring**: Continuously monitors the health of distributed training actors +2. **Backup Actor Pool**: Pre-allocated backup actors ready to replace failed workers +3. **Automatic Recovery**: Seamlessly recovers from failures by: + - Detecting dead actors + - Destroying old process groups + - Replacing failed actors with backup actors + - Re-initializing process groups with new world size + - Reloading model and optimizer state from checkpoints + +4. **Distributed Checkpointing**: Implements efficient sharded checkpoint saving/loading using Megatron's distributed checkpointing +5. **Process Group Management**: Handles NCCL process group initialization, destruction, and re-initialization + +### Parallelism Support + +- **Data Parallelism (DP)**: Distributes training data across multiple GPUs +- **Tensor Parallelism (TP)**: Splits model tensors across GPUs +- **Pipeline Parallelism (PP)**: Distributes model layers across GPUs +- **Context Parallelism (CP)**: Enables sequence parallelism for long contexts + +### Advanced Training Features + +- **PPO Training**: Implements Proximal Policy Optimization with micro-batch accumulation +- **Mixed Precision**: Supports BF16 training for improved performance +- **Gradient Accumulation**: Handles micro-batches with automatic gradient accumulation +- **Distributed Optimizer**: Uses Megatron's distributed optimizer for memory efficiency + +## Architecture + +### Core Components + +1. **MegatronActor** (`megatron_actor.py`): + - Individual training actor wrapping Megatron models + - Handles model initialization, forward/backward passes, and checkpointing + - Supports dynamic process group re-initialization + +2. **MegatronActorGroup** (`megatron_actor.py`): + - Manages a group of distributed actors + - Implements fault recovery logic + - Coordinates distributed training operations + +3. **Dispatch System** (`dispatch.py`): + - **MeshDispatch**: Distributes data across the device mesh (DP, SP, TP, PP) + - **PassThroughDispatch**: Broadcasts same data/commands to all actors + - Handles data sharding and result collection + +4. **Training Batch** (`training_batch.py`): + - Defines input/output batch structures for PPO training + - Supports chunking and concatenation for distributed operations + +5. **Checkpoint I/O** (`file_io.py`): + - Cloud-aware file I/O supporting S3, GCS, and local storage + - Efficient checkpoint upload/download with parallel transfers + +## Getting Started + +### Quick Start + +```bash +uv run --isolated main.py +``` + +This will: +1. Create a placement group with workers and backup GPUs +2. Initialize the actor group and model +3. Run a training step +4. Save a checkpoint +5. Simulate a failure by killing actors +6. Recover from the failure using backup actors +7. 
Resume training after recovery + +### Configuration + +Edit the `Config` class in `main.py` to customize: + +```python +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" # HuggingFace model name + num_nodes: int = 1 + num_gpus_per_node: int = 4 + num_spare_gpus: int = 4 # Backup actors for fault tolerance + mini_batch_size: int = 16 + micro_train_batch_size_per_gpu: int = 2 + + # Megatron parallelism settings + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) +``` + +### Megatron Parallelism Configuration + +```python +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 # TP degree + pipeline_model_parallel_size: int = 1 # PP degree + context_parallel_size: int = 1 # CP degree + expert_model_parallel_size: int = 1 # For MoE models +``` + +## Fault Recovery Workflow + +1. **Training Phase**: + - Actors perform distributed training using Megatron + - Periodic checkpoints saved to cloud storage + +2. **Failure Detection**: + - System detects actor failures via health checks + - Identifies affected data parallel groups + +3. **Recovery Process**: + - Destroy old process groups on healthy actors + - Pop backup actors from the backup pool + - Insert backup actors at failed ranks + - Update world size and reassign ranks + - Re-initialize process groups with new configuration + - Reload model/optimizer state from checkpoint + +4. **Resume Training**: + - Continue training with recovered actor group + - No loss of training progress (from last checkpoint) + +## Advanced Usage + +### Custom Dispatch Types + +Register custom dispatch strategies: + +```python +from dispatch import register_dispatch_type, Dispatch + +class CustomDispatch(Dispatch): + # Implement dispatch, collect, and validate methods + pass + +register_dispatch_type("custom", CustomDispatch) +``` + +### CPU Offloading (Experimental) + +For faster recovery, offload model/optimizer state to CPU memory: + +```python +# Before failure +ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + +# After recovery, on healthy actors +ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu")) +``` + +## Dependencies + +See `pyproject.toml` for full dependency list. 
Key dependencies: +- Ray for distributed orchestration +- Megatron-Core for model parallelism +- PyTorch with CUDA support +- Transformers for model loading +- vLLM and related libraries + +## Running on Anyscale + +Submit the job using: + +```bash +anyscale job submit -f job.yaml +``` + +The job configuration in `job.yaml` specifies: +- Container image with dependencies +- GPU instance types (g6e.12xlarge with 4xL4) +- Resource limits and scaling +- Environment variables for NCCL configuration + +## Limitations and Future Work + +- Virtual pipeline parallelism not yet supported +- CPU offloading optimization in progress +- Async checkpoint saving planned for future releases + +## References + +- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +- [Ray Documentation](https://docs.ray.io/) +- [Anyscale Platform](https://docs.anyscale.com/) diff --git a/megatron_ray_fault_tolerant/dispatch.py b/megatron_ray_fault_tolerant/dispatch.py new file mode 100644 index 0000000..9949c48 --- /dev/null +++ b/megatron_ray_fault_tolerant/dispatch.py @@ -0,0 +1,299 @@ +"""Defines dispatch and collect logic for distributed training""" + +from dataclasses import dataclass +from ray.actor import ActorHandle +from typing import List, Tuple, Optional, Dict, Type, Any +import asyncio +from abc import ABC, abstractmethod +import ray +from ray import ObjectRef +from training_batch import TrainingInputBatch, TrainingOutputBatch +import inspect + + +@dataclass +class MeshRank: + """Represents a rank in the device mesh. + + This is a tuple of (DP, SP, TP, PP) ranks. + """ + + dp: int + sp: int + tp: int + pp: int + + world_size: int + dp_size: int + pp_size: int + + def is_collection_dp_rank(self) -> bool: + """Check if this rank is a DP rank to collect from + + This is the rank with (SP=0, TP=0, PP=pp_size-1) + + Note: double check this for ETP > 1 (but this is not a typically used case) + """ + return self.tp == 0 and self.pp == self.pp_size - 1 and self.sp == 0 + + def __str__(self) -> str: + return f"MeshRank(dp={self.dp}, sp={self.sp}, tp={self.tp}, pp={self.pp}, world_size={self.world_size}, dp_size={self.dp_size}, pp_size={self.pp_size})" + + def __repr__(self) -> str: + return self.__str__() + + +@dataclass +class ActorInfo: + """Actor information for distributed training. + + This includes the actor handle and the rank in the device mesh. 
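+
+    Example (illustrative values for a 4-GPU mesh with DP=2, TP=2, PP=1):
+
+        info = ActorInfo(
+            handle=some_actor_handle,
+            rank=MeshRank(dp=0, sp=0, tp=0, pp=0,
+                          world_size=4, dp_size=2, pp_size=1),
+        )
+        info.rank.is_collection_dp_rank()  # True: SP=0, TP=0, PP=pp_size-1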
+ """ + + handle: ActorHandle + rank: MeshRank + + +class Dispatch(ABC): + """Base class for dispatch types + + Dispatch types are responsible for: + - dispatching method calls to actors handling data sharding if necessary + - collecting results from actors and concatenating results if necessary + - validating arguments for dispatch + """ + + @classmethod + @abstractmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + """Dispatches method calls to the actors with data sharing if necessary.""" + pass + + @classmethod + @abstractmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors asynchronously in an asyncio-compatible way.""" + pass + + @classmethod + @abstractmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors synchronously and returns a `TrainingOutputBatch`.""" + pass + + @classmethod + @abstractmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + """Validate and process arguments for dispatch. + + Returns: + Tuple of (args, kwargs) to be passed to dispatch + """ + pass + + +class MeshDispatch(Dispatch): + """Mesh dispatch type to dispatch data to a group of actors along the device mesh. + + Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. + The actor method should accept a single argument - the data batch. + + For data dispatch: + + * The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism. + * Each actor with the same DP rank processes the same data chunk in parallel. + + For data collection: + + * Data is collected only from the primary rank of each model/sequence parallel group. + * The primary rank is defined as the rank with (SP=0, TP=0, PP=0). + * The collected chunks are concatenated in order of DP rank to reconstruct the full data. + + Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1: + + * Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, + and all actors with DP rank 1 process the second chunk. + * Data collection: Only two actors contribute to the final output - the primary rank from each DP group: + (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order. 
+ + """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch + ) -> List[ObjectRef]: + assert len(actor_infos) > 0, "actor_infos must be a non-empty list" + object_refs = [] + dp_size = actor_infos[0].rank.dp_size + assert ( + len(data) % dp_size == 0 + ), "data batch size must be divisible by dp_size, got {} and {}".format( + len(data), dp_size + ) + chunk_size = len(data) // dp_size + data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size) + + for actor_info in actor_infos: + # index into tensordict to get the correct data to send + data_to_send = data_chunks[actor_info.rank.dp] + object_refs.append(getattr(actor_info.handle, method).remote(data_to_send)) + return object_refs + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = ray.get(object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + # all should be none + assert all( + obj is None for obj in all_objects + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + sig = inspect.signature(cls.dispatch) + # pass dummy actor_infos and method_name + bound_args = sig.bind([], "dummy", *args, **kwargs) + bound_args.apply_defaults() + data = bound_args.arguments.get("data") + + # Check if there are any extra arguments + if len(bound_args.arguments) > 3: # data, actor_infos, method_name + # remove actor_infos and method_name - not added by user + bound_args.arguments.pop("actor_infos") + bound_args.arguments.pop("method") + raise ValueError( + f"MeshDispatch only accepts 'data' as an argument, got extra args: {bound_args.arguments}" + ) + + data = bound_args.arguments.get("data") + if not isinstance(data, TrainingInputBatch): + raise ValueError( + f"For MeshDispatch, `data` entry should be a `TrainingInput`, got {data}" + ) + args = (data,) + kwargs = {} + return args, kwargs + + +class PassThroughDispatch(Dispatch): + """PassThrough dispatch type to dispatch data to a group of actors without any sharding. + + This is useful for cases where we want to run the same method on all the actors. + Supports methods with any number of arguments. 
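+
+    For example, control-plane methods such as `offload_to_cpu` or `backload_to_gpu`
+    are dispatched this way so that every actor in the group executes them with
+    identical arguments.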
+ """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + return [ + getattr(actor_info.handle, method).remote(*args, **kwargs) + for actor_info in actor_infos + ] + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + data_batches = ray.get(object_refs) + if len(data_batches) > 0 and data_batches[0] is not None: + assert isinstance( + data_batches[0], TrainingOutputBatch + ), "data_batches must be a list of `TrainingOutputBatch` objects" + return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches) + # all should be none + assert all( + obj is None for obj in data_batches + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + # no validation needed just pass everything + return args, kwargs + + +class DispatchRegistry: + _registry: Dict[str, Type[Dispatch]] = { + "mesh": MeshDispatch, + "pass_through": PassThroughDispatch, + } + + @classmethod + def register(cls, name: str, dispatch_class: Type[Dispatch]) -> None: + """Register a new dispatch type.""" + assert issubclass(dispatch_class, Dispatch) + cls._registry[name] = dispatch_class + + @classmethod + def get(cls, name: str) -> Type[Dispatch]: + """Get a registered dispatch type.""" + if name not in cls._registry: + raise KeyError(f"Dispatch type '{name}' not registered") + return cls._registry[name] + + @classmethod + def list_registered(cls) -> Dict[str, Type[Dispatch]]: + """List all registered dispatch types.""" + return cls._registry + + +def register_dispatch_type(name: str, dispatch_class: Type) -> None: + DispatchRegistry.register(name, dispatch_class) + + +def concatenate_outputs_after_mesh_dispatch( + actor_infos: List[ActorInfo], data_batches: List[TrainingOutputBatch] +) -> TrainingOutputBatch: + """Concatenate data batches from different ranks after mesh dispatch. + + - Data is collected only from the primary DP rank. + - The collected chunks are concatenated in order of DP rank to reconstruct the full data. + """ + assert len(actor_infos) == len( + data_batches + ), "`actor_infos` and `data_batches` must have the same length" + shards = [] + # collect in-order + dp_rank_to_shard = {} + for actor_info, data_batch in zip(actor_infos, data_batches): + if actor_info.rank.is_collection_dp_rank(): + dp_rank = actor_info.rank.dp + dp_rank_to_shard[dp_rank] = data_batch + for i in range(actor_infos[0].rank.dp_size): + shards.append(dp_rank_to_shard[i]) + return TrainingOutputBatch.cat(shards) diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py new file mode 100644 index 0000000..932adbe --- /dev/null +++ b/megatron_ray_fault_tolerant/file_io.py @@ -0,0 +1,321 @@ +""" +File I/O utilities for handling both local filesystem and cloud storage (S3/GCS). 
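+S3 calls are retried once with refreshed credentials when temporary session tokens expire.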
+ +This module provides a unified interface for file operations that works with: +- Local filesystem paths +- S3 paths (s3://bucket/path) +- Google Cloud Storage paths (gs://bucket/path or gcs://bucket/path) + +Uses fsspec for cloud storage abstraction. +""" + +import os +import tempfile +from contextlib import contextmanager +import fsspec +from loguru import logger +from datetime import datetime, timezone, timedelta + +# Optional AWS deps (present when s3fs is installed) +try: + import botocore.session as _botocore_session + from botocore.exceptions import ClientError + + _HAS_BOTOCORE = True +except Exception: + _HAS_BOTOCORE = False + + class ClientError(Exception): # fallback type + pass + + +_S3_FS = None # type: ignore + + +def get_s3_fs(): + """Return a cached S3 filesystem instance, creating it once.""" + global _S3_FS + if _S3_FS is None: + _S3_FS = fsspec.filesystem("s3") + return _S3_FS + + +def s3_expiry_time(): + """Return botocore credential expiry (datetime in UTC) or None.""" + if not _HAS_BOTOCORE: + return None + try: + sess = _botocore_session.get_session() + creds = sess.get_credentials() + if not creds: + return None + return getattr(creds, "expiry_time", None) or getattr( + creds, "_expiry_time", None + ) + except Exception: + return None + + +def s3_refresh_if_expiring(fs) -> None: + """ + Simple refresh: + - If expiry exists and is within 300s (or past), refresh with fs.connect(refresh=True). + - Otherwise, do nothing. + """ + exp = s3_expiry_time() + if not exp: + return + now = datetime.now(timezone.utc) + if now >= exp - timedelta(seconds=300): + try: + fs.connect(refresh=True) # rebuild session + except Exception: + pass + + +def call_with_s3_retry(fs, fn, *args, **kwargs): + """ + Wrapper for calling an S3 method. If it fails with ExpiredToken, force refresh once and retry. + """ + try: + return fn(*args, **kwargs) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fn(*args, **kwargs) + raise + + +def is_cloud_path(path: str) -> bool: + """Check if the given path is a cloud storage path.""" + return path.startswith(("s3://", "gs://", "gcs://")) + + +def _get_filesystem(path: str): + """Get the appropriate filesystem for the given path.""" + if not is_cloud_path(path): + return fsspec.filesystem("file") + + proto = path.split("://", 1)[0] + if proto == "s3": + fs = get_s3_fs() + s3_refresh_if_expiring(fs) + return fs + return fsspec.filesystem(proto) + + +def open_file(path: str, mode: str = "rb"): + """Open a file using fsspec, works with both local and cloud paths.""" + if not is_cloud_path(path): + return fsspec.open(path, mode) + + fs = _get_filesystem(path) + norm = fs._strip_protocol(path) + try: + return fs.open(norm, mode) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fs.open(norm, mode) + raise + + +def makedirs(path: str, exist_ok: bool = True) -> None: + """Create directories. 
Only applies to local filesystem paths.""" + if not is_cloud_path(path): + os.makedirs(path, exist_ok=exist_ok) + + +def exists(path: str) -> bool: + """Check if a file or directory exists.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.exists, path) + return fs.exists(path) + + +def isdir(path: str) -> bool: + """Check if path is a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.isdir, path) + return fs.isdir(path) + + +def list_dir(path: str) -> list[str]: + """List contents of a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.ls, path, detail=False) + return fs.ls(path, detail=False) + + +def remove(path: str) -> None: + """Remove a file or directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + if call_with_s3_retry(fs, fs.isdir, path): + call_with_s3_retry(fs, fs.rm, path, recursive=True) + else: + call_with_s3_retry(fs, fs.rm, path) + return + if fs.isdir(path): + fs.rm(path, recursive=True) + else: + fs.rm(path) + + +def upload_directory(local_path: str, cloud_path: str) -> None: + """Upload a local directory to cloud storage. + + Uploads the contents of local_path to cloud_path, not the directory itself. + This ensures consistent behavior across all ranks by explicitly uploading each file. + """ + if not is_cloud_path(cloud_path): + raise ValueError(f"Destination must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + + # Normalize paths: ensure cloud_path ends with / to indicate directory + cloud_path_normalized = cloud_path.rstrip("/") + "/" + + # Walk the local directory and upload each file explicitly + # This ensures we upload contents, not the directory as a subdirectory + for root, dirs, files in os.walk(local_path): + for file in files: + local_file_path = os.path.join(root, file) + # Get relative path from local_path to maintain directory structure + rel_path = os.path.relpath(local_file_path, local_path) + # Construct remote path: cloud_path/rel_path + remote_file_path = cloud_path_normalized + rel_path + + if cloud_path.startswith("s3://"): + # For S3, strip protocol for fsspec operations + remote_file_path_stripped = fs._strip_protocol(remote_file_path) + # Ensure parent directories exist in S3 (fsspec handles this automatically) + call_with_s3_retry( + fs, fs.put, local_file_path, remote_file_path_stripped + ) + else: + fs.put(local_file_path, remote_file_path) + + logger.info(f"Uploaded contents of {local_path} to {cloud_path}") + + +def download_directory(cloud_path: str, local_path: str) -> None: + """Download a cloud directory to local storage.""" + if not is_cloud_path(cloud_path): + raise ValueError(f"Source must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + cloud_path_normalized = cloud_path.rstrip("/") + "/" + os.makedirs(local_path, exist_ok=True) + + # List all files and download each one individually to download contents, not the folder + if cloud_path.startswith("s3://"): + remote_path_stripped = fs._strip_protocol(cloud_path_normalized) + all_files = call_with_s3_retry(fs, fs.find, remote_path_stripped, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(remote_path_stripped) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + 
parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + call_with_s3_retry(fs, fs.get, remote_file, local_file_path) + else: + all_files = fs.find(cloud_path_normalized, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(cloud_path_normalized) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + fs.get(remote_file, local_file_path) + + logger.info(f"Downloaded {cloud_path} to {local_path}") + + +@contextmanager +def local_work_dir(output_path: str): + """ + Context manager that provides a local working directory. + + For local paths, returns the path directly. + For cloud paths, creates a temporary directory and uploads content at the end. + + Args: + output_path: The final destination path (local or cloud) + + Yields: + str: Local directory path to work with + + Example: + with local_work_dir("s3://bucket/model") as work_dir: + # Save files to work_dir + model.save_pretrained(work_dir) + # Files are automatically uploaded to s3://bucket/model at context exit + """ + if is_cloud_path(output_path): + with tempfile.TemporaryDirectory() as temp_dir: + try: + yield temp_dir + finally: + # Upload everything from temp_dir to cloud path + upload_directory(temp_dir, output_path) + logger.info(f"Uploaded directory contents to {output_path}") + else: + # For local paths, ensure directory exists and use it directly + makedirs(output_path, exist_ok=True) + yield output_path + + +@contextmanager +def local_read_dir(input_path: str): + """ + Context manager that provides a local directory with content from input_path. + + For local paths, returns the path directly. + For cloud paths, downloads content to a temporary directory. + + Args: + input_path: The source path (local or cloud) + + Yields: + str: Local directory path containing the content + + Example: + with local_read_dir("s3://bucket/model") as read_dir: + # Load files from read_dir + model = AutoModel.from_pretrained(read_dir) + """ + if is_cloud_path(input_path): + with tempfile.TemporaryDirectory() as temp_dir: + # Download everything from cloud path to temp_dir + download_directory(input_path, temp_dir) + logger.info(f"Downloaded directory contents from {input_path}") + yield temp_dir + else: + # For local paths, use directly (but check it exists) + if not exists(input_path): + raise FileNotFoundError(f"Path does not exist: {input_path}") + yield input_path diff --git a/megatron_ray_fault_tolerant/job.yaml b/megatron_ray_fault_tolerant/job.yaml new file mode 100644 index 0000000..f1c2de2 --- /dev/null +++ b/megatron_ray_fault_tolerant/job.yaml @@ -0,0 +1,45 @@ +# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. + +name: megatron-fault-tolerance + +# When empty, use the default image. This can be an Anyscale-provided base image +# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided +# that it meets certain specs), or you can build new images using the Anyscale +# image builder at https://console.anyscale-staging.com/v2/container-images. +# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 +containerfile: ./Dockerfile + +# When empty, Anyscale will auto-select the instance types. You can also specify +# minimum and maximum resources. +compute_config: + # Pin worker nodes to g6.xlarge (1xL4) so the vision workload lands on L4 GPUs. 
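+  # Each g6e.12xlarge worker node provides 4 GPUs and 48 vCPUs, so max_nodes: 2 caps
+  # this job at 8 worker GPUs (training ranks plus hot spares).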
+ worker_nodes: + - instance_type: g6e.12xlarge + min_nodes: 0 + max_nodes: 2 + min_resources: + CPU: 0 + GPU: 0 + max_resources: + CPU: 384 + GPU: 64 + +# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that +# will be the working directory for the job. The files in the directory will be +# automatically uploaded to the job environment in Anyscale. +working_dir: . + +# When empty, this uses the default Anyscale Cloud in your organization. +cloud: + +env_vars: + RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5" + NCCL_P2P_DISABLE: "1" + NCCL_SHM_DISABLE: "1" + +# The script to run in your job. You can also do "uv run main.py" if you have a +# pyproject.toml file in your working_dir. +entrypoint: uv run --isolated main.py + +# If there is an error, do not retry. +max_retries: 0 \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py new file mode 100644 index 0000000..b64b535 --- /dev/null +++ b/megatron_ray_fault_tolerant/main.py @@ -0,0 +1,190 @@ +import os +from dataclasses import dataclass, field +import ray +from typing import Optional, List +from megatron_actor import MegatronActorGroup +from ray.util.placement_group import placement_group + +import random +import time +from utils import get_test_training_batch, get_reordered_bundle_indices + + +@dataclass +class DDPConfig: + grad_reduce_in_fp32: bool = True + overlap_grad_reduce: bool = False + overlap_param_gather: bool = False + average_in_collective: bool = True + + +@dataclass +class OptimizerConfig: + lr: float = 1.0e-6 + weight_decay: float = 1e-2 + max_grad_norm: float = 1.0 + offload_after_step: bool = True + num_warmup_steps: int = 0 + scheduler: str = "constant_with_warmup" + + +@dataclass +class TransformerConfig: + recompute_granularity: Optional[str] = None + recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"]) + recompute_method: Optional[str] = None + recompute_num_layers: Optional[int] = None + + +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + context_parallel_size: int = 1 + expert_model_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + ddp_config: DDPConfig = field(default_factory=DDPConfig) + optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) + transformer_config: TransformerConfig = field(default_factory=TransformerConfig) + + +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" + # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it + num_nodes: int = 1 + num_gpus_per_node: int = 4 + mini_batch_size: int = 16 + num_spare_gpus: int = 4 + micro_train_batch_size_per_gpu: int = 2 + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) + ckpt_dir: str = ( + os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt3/" + ) + # algorithm config + eps_clip_low: float = 0.2 + eps_clip_high: float = 0.2 + clip_ratio_c: float = 3.0 + + +def main(): + config = Config() + # create placement group including spare gpus + pg = placement_group( + [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node + + [{"GPU": 1, "CPU": 1}] * config.num_spare_gpus, + strategy="PACK", + ) + ray.get(pg.ready(), timeout=1200) + # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 + reordered_bundle_indices = get_reordered_bundle_indices(pg) + + 
actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_nodes, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[:-config.num_spare_gpus], + ) + actor_group.initiate_worker_process_group() + ray.get(actor_group.async_init_model(config.model)) + + # potentially need some time for dependencies like transformer-engine-torch to build on worker nodes (this is something good to warm start...) + backup_actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_spare_gpus // config.num_gpus_per_node, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[-config.num_spare_gpus:], + ) + # just place but don't initiate the worker process group for the backup actor group + # call a function to make sure the actors are placed + ray.get(backup_actor_group.async_run_method_no_dispatch("get_gpu_id")) + + # train on one batch + batch = get_test_training_batch(config.model, batch_size=32) + print("Starting training step 1...") + start_time = time.time() + ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) + print(f"Training step 1 took {time.time() - start_time:.2f} seconds") + + # save checkpoint + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") + + # TODO: add a cpu offload (or cpu save memory) call here + # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + + # TODO: run another training batch here and save results but don't save checkpoint + + # randomly kill an actor to simulate fault tolerance scenario + # TODO: go deeper into the actor code and throw an exception on a given node and catch it here + print("Simulating failure and recovery...") + start_time = time.time() + + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) + # get the whole dp group associated with the failed actor + dp_group_actors = [] + for actor_info in actor_group.actor_infos: + if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: + dp_group_actors.append(actor_info) + print( + f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." 
+ ) + for actor_info in dp_group_actors: + ray.kill(actor_info.handle) + + # Destroy process groups on all actors (including dead ones, which will fail gracefully) + print("Destroying old process groups...") + try: + ray.get( + actor_group.async_run_ray_method( + "pass_through", "destroy_worker_process_group" + ) + ) + except Exception as e: + print(f"Some actors failed during destroy (expected): {e}") + + for i, actor_info in enumerate(actor_group.actor_infos): + is_alive = actor_group._check_actor_alive(actor_info.handle) + print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") + + # Recover from failure: remove dead actors and re-initialize process group + print("Recovering from actor failure...") + actor_group.recover_from_failure(backup_actor_group) + + # load checkpoint on all actors + # TODO: improve the logic here + # we want to only call load checkpoint on the actors that are fresh + # on previously healthy actors we want to restore weights and optimizer state from cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu"), actor_ids=[previously healthy actor ids]) + # only for new actors, we want to load the checkpoint + ray.get( + actor_group.async_run_ray_method( + "pass_through", "load_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Recovery took {time.time() - start_time:.2f} seconds") + + # TODO: check that results here are the same as before the failure when resuming from checkpoint + # Test that training still works after recovery + print("Testing training after recovery...") + batch_after_recovery = get_test_training_batch(config.model, batch_size=32) + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "ppo_train", batch_after_recovery + ) + ) + print(f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds") + print("Recovery successful! 
Training works with remaining actors.") + + +if __name__ == "__main__": + main() diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py new file mode 100644 index 0000000..c1789de --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -0,0 +1,934 @@ +import logging +import os +import random +import socket +from dataclasses import asdict +from tqdm import tqdm +from typing import Optional, Dict, Any, List +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +import ray +from ray import ObjectRef +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer +from loguru import logger + +# megatron +from megatron.bridge import AutoBridge +import megatron.core.parallel_state as mpu +from megatron.core import dist_checkpointing +from megatron.core.dist_checkpointing.strategies import base as ckpt_base +from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) + +# local imports +import file_io as io # local io module to support cloud storage for checkpointing +from training_batch import TrainingOutputBatch +from optimizer import ( + init_megatron_optim_config, + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, +) +from megatron_model_wrapper import MegatronModelWrapper +from megatron_utils import ( + offload_megatron_model_to_cpu, + offload_megatron_optimizer, + load_megatron_model_to_gpu, + load_megatron_optimizer, + offload_megatron_grads_to_cpu, + load_megatron_grads_to_gpu, +) +from utils import BatchIterator +from dispatch import DispatchRegistry, Dispatch, ActorInfo, MeshRank + + +@ray.remote(num_gpus=1) +class MegatronActor: + def __init__( + self, + world_size, + rank, + local_rank, + master_addr, + master_port, + megatron_config, + seed, + cfg, + ): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self._world_size = world_size + self._rank = rank + self._local_rank = local_rank + self._master_addr = master_addr if master_addr else self._get_current_node_ip() + self._master_port = master_port if master_port else self._get_free_port() + os.environ["MASTER_ADDR"] = self._master_addr + os.environ["MASTER_PORT"] = str(self._master_port) + os.environ["WORLD_SIZE"] = str(self._world_size) + os.environ["RANK"] = str(self._rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + os.environ["LOCAL_RANK"] = "0" + self.megatron_config = megatron_config + self.seed = seed + self.cfg = cfg + + def get_node_local_rank(self): + return self._local_rank + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if torch.cuda.device_count() > 0: + from megatron.core import tensor_parallel + + tensor_parallel.model_parallel_cuda_manual_seed(seed) + + def init_worker_process_group(self): + """Initialize worker process group and megatron model parallel.""" + # Destroy any existing 
process group first to ensure clean state + if torch.distributed.is_initialized(): + try: + torch.distributed.destroy_process_group() + except Exception: + pass # Ignore errors if already destroyed + + # Initialize process group using environment variables + torch.distributed.init_process_group(backend="nccl") + + local_rank = int(os.environ.get("LOCAL_RANK", "-1")) + if local_rank != -1: + torch.cuda.set_device(local_rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.megatron_config.tensor_model_parallel_size, + pipeline_model_parallel_size=self.megatron_config.pipeline_model_parallel_size, + expert_model_parallel_size=self.megatron_config.expert_model_parallel_size, + expert_tensor_parallel_size=self.megatron_config.expert_tensor_parallel_size, + use_sharp=False, + context_parallel_size=self.megatron_config.context_parallel_size, + nccl_communicator_config_path=None, + ) + self.set_seed(self.seed) + self.world_size = dist.get_world_size() + self.mesh_rank = MeshRank( + dp=mpu.get_data_parallel_rank(), + sp=mpu.get_context_parallel_rank(), + tp=mpu.get_tensor_model_parallel_rank(), + pp=mpu.get_pipeline_model_parallel_rank(), + world_size=self._world_size, + dp_size=mpu.get_data_parallel_world_size(), + pp_size=mpu.get_pipeline_model_parallel_world_size(), + ) + + def get_mesh_rank(self): + return self.mesh_rank + + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + def print(self, *msg): + """Print only on rank 0""" + if dist.get_rank() == 0: + logger.info(*msg) + + @staticmethod + def _get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + def get_ray_node_id(self): + return ray.get_runtime_context().get_node_id() + + @staticmethod + def get_rng_state(): + """Get current RNG state for reproducibility""" + rng_state = { + "cpu": torch.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + + # Only save CUDA RNG state if CUDA is available and being used + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + rng_state["cuda"] = torch.cuda.get_rng_state() + + return rng_state + + @staticmethod + def load_rng_state(rng_state): + """Load RNG state for reproducibility""" + torch.set_rng_state(rng_state["cpu"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) + + # Only restore CUDA RNG state if it was saved and CUDA is available + if ( + "cuda" in rng_state + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + torch.cuda.set_rng_state(rng_state["cuda"]) + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + def destroy_worker_process_group(self): + mpu.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # Clear stale env vars + for env_var in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]: + if env_var in os.environ: + del os.environ[env_var] + + def reinit_model_after_recovery(self): + """Re-initialize model and optimizer after process group recovery. + + This is needed because the model and optimizer were created with the old + process group and still have references to old NCCL communicators. + + We need to fully reinitialize the provider and model to ensure they use + the new process group. 
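+
+        Rebuilds, in order: the bridge/provider, the DDP-wrapped module, the optimizer,
+        the LR scheduler, and the model wrapper, and then re-normalizes the per-GPU
+        mini batch size for the new data-parallel world size.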
+ """ + if not hasattr(self, "_model_path") or self._model_path is None: + # Fall back to cfg.model if _model_path not set + if hasattr(self.cfg, "model"): + model_path = self.cfg.model + else: + logger.warning("No model path found, cannot re-initialize model") + return + else: + model_path = self._model_path + + num_training_steps = getattr(self, "_num_training_steps", 1e9) + + logger.info("Re-initializing model components after process group recovery...") + + # Re-initialize the bridge and provider with the new process group + # This ensures all NCCL communicators are created fresh + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # Recreate the DDP-wrapped module with the new process group + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + # Recreate optimizer with the new process group + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + # Recreate scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # Recreate model wrapper + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # Re-normalize mini batch size with new world size + self._normalize_mini_batch_size() + + logger.info("Model components re-initialized successfully") + + def update_world_size(self, new_world_size: int): + """Update the world_size stored in the actor.""" + self._world_size = new_world_size + os.environ["WORLD_SIZE"] = str(new_world_size) + + def update_rank(self, new_rank: int): + """Update the rank stored in the actor.""" + self._rank = new_rank + os.environ["RANK"] = str(new_rank) + + def update_master_addr_port(self, master_addr: str, master_port: int): + """Update the master address and port for process group initialization.""" + self._master_addr = master_addr + self._master_port = master_port + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + def _normalize_mini_batch_size(self): + """ + Normalize mini batch sizes to per-gpu mini batch sizes. + """ + if not hasattr(self, "mesh_rank") or self.mesh_rank is None: + raise RuntimeError( + "mesh_rank must be initialized before calling _normalize_mini_batch_size()" + ) + + dp_size = self.mesh_rank.dp_size + self.policy_mini_batch_size_per_gpu = self.cfg.mini_batch_size // dp_size + + def ppo_train(self, train_data) -> "TrainingOutputBatch": + """ + Overrides `PolicyWorkerBase.ppo_train` for megatron. + + Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the + worker MegatronModelWrapper.forward_backward_mini_batch method. 
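+
+        For example, with `mini_batch_size=16`, a data-parallel size of 4, and
+        `micro_train_batch_size_per_gpu=2`, each rank accumulates 16 // 4 // 2 = 2
+        micro batches before taking a single optimizer step.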
+        """
+        dataloader = BatchIterator(
+            train_data,
+            sample_batch_size=self.cfg.micro_train_batch_size_per_gpu,
+            drop_last=False,
+        )
+
+        micro_batches_per_mini_batch = (
+            self.policy_mini_batch_size_per_gpu
+            // self.cfg.micro_train_batch_size_per_gpu
+        )
+
+        self.optimizer.zero_grad()
+        pbar = tqdm(
+            dataloader,
+            desc="ppo train",
+            disable=not dist.get_rank() == 0,
+        )
+
+        micro_buffer = []
+        for local_step, experience in enumerate(pbar):
+            experience.to_device(torch.cuda.current_device())
+            sequences = experience.sequences
+            attention_mask = experience.attention_mask
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 0)
+
+            micro_buffer.append(
+                {
+                    "sequences": sequences,
+                    "attention_mask": attention_mask,
+                    "position_ids": position_ids,
+                    "num_actions": experience.num_actions,
+                    "old_action_log_probs": experience.action_log_probs,
+                    "base_action_log_probs": experience.base_action_log_probs,
+                    "advantages": experience.advantages,
+                    "loss_mask": experience.loss_mask,
+                    "rollout_action_logprobs": experience.rollout_logprobs,
+                }
+            )
+
+            if len(micro_buffer) == micro_batches_per_mini_batch:
+                # run mini-batch forward-backward and then one optimizer step
+                self.model.train()
+                for chunk in self.actor_module:
+                    # if using a distributed optimizer, zeroing the grad buffer will be handled by the optimizer
+                    chunk.zero_grad_buffer()
+                seq_len = micro_buffer[0]["sequences"].shape[1]
+                micro_bsz = micro_buffer[0]["sequences"].shape[0]
+
+                self.model.forward_backward_mini_batch(
+                    micro_batches=micro_buffer,
+                    seq_len=seq_len,
+                    micro_batch_size=micro_bsz,
+                )
+
+                _, grad_norm, _ = self.optimizer.step()
+                self.scheduler.step(1)
+                self.optimizer.zero_grad()
+                # reset the accumulation buffer so the next mini batch triggers another optimizer step
+                micro_buffer = []
+
+        torch.distributed.barrier()
+
+    def save_checkpoint(self, ckpt_dir: str):
+        # Extract base model.
+        model: List[nn.Module] = self.model.actor_module
+        optimizer = self.optimizer
+        scheduler = self.scheduler
+        node_local_rank = self.get_node_local_rank()
+        assert (
+            len(model) == 1
+        ), "Megatron virtual pipeline parallel is not yet supported"
+        model = model[0]
+        if hasattr(model, "module"):
+            model = model.module
+
+        # Create checkpoint directory if it doesn't exist.
+        if node_local_rank == 0:
+            io.makedirs(ckpt_dir, exist_ok=True)
+
+        # All ranks wait for the checkpoint directory to be created before saving.
+        dist.barrier()
+
+        # Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler.
+        sharded_state_dict = {}
+        model_sharded_state_dict = model.sharded_state_dict()
+        sharded_state_dict["model"] = model_sharded_state_dict
+        if optimizer:
+            sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(
+                model_sharded_state_dict
+            )
+        if scheduler:
+            sharded_state_dict["lr_scheduler"] = scheduler.state_dict()
+
+        # Save RNG state.
+        sharded_state_dict["rng"] = self.get_rng_state()
+
+        # Save the checkpoint across ranks in parallel.
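+        # The fully parallel wrapper spreads the write work across the data-parallel
+        # (and context-parallel) group, so each rank saves a distinct subset of shards
+        # instead of every rank writing duplicate copies.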
+ save_strategy = get_default_save_sharded_strategy("torch_dist") + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + + with io.local_work_dir(ckpt_dir) as work_dir: + # synchronous checkpointing for now + async_save_request = dist_checkpointing.save( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=work_dir, + sharded_strategy=save_strategy, + async_sharded_save=False, + validate_access_integrity=True, + ) + assert ( + async_save_request is None + ), "Async save is not yet supported for Megatron" + + dist.barrier() + ckpt_base.async_calls.close() + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + self.print(f"Checkpoint successfully saved to {ckpt_dir}") + + def load_checkpoint( + self, + ckpt_dir: str, + load_module_strict: bool = True, + load_optimizer_states: bool = True, + load_lr_scheduler_states: bool = True, + ): + if not ckpt_dir or not io.exists(ckpt_dir): + raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}") + + # Extract base model. + model: List[nn.Module] = self.model.actor_module + optimizer = self.optimizer + scheduler = self.scheduler + assert ( + len(model) == 1 + ), "Megatron virtual pipeline parallel is not yet supported" + unwrapped_model = model[0] + if hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + + # Extract sharded state dicts. + sharded_state_dict = {} + model_sharded_state_dict = unwrapped_model.sharded_state_dict() + sharded_state_dict["model"] = model_sharded_state_dict + if optimizer and load_optimizer_states: + sharded_state_dict["optimizer"] = optimizer.sharded_state_dict( + model_sharded_state_dict + ) + if scheduler and load_lr_scheduler_states: + sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + + # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory + # this should be improved to download only the relevant shards for this actor to load + with io.local_read_dir(ckpt_dir) as read_dir: + # Load the checkpoint in parallel. + load_strategy = get_default_load_sharded_strategy(read_dir) + load_strategy = FullyParallelLoadStrategyWrapper( + load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=read_dir, + sharded_strategy=load_strategy, + ) + + # Load the model, optimizer, and scheduler state dicts. + assert ( + "model" in state_dict + ), f"Model state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + model[0].load_state_dict(state_dict["model"], strict=load_module_strict) + self.print("Loaded model state dict.") + + if optimizer and load_optimizer_states: + assert ( + "optimizer" in state_dict + ), f"Optimizer state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + optimizer.load_state_dict(state_dict["optimizer"]) + self.print("Loaded optimizer state dict.") + + if scheduler and load_lr_scheduler_states: + assert ( + "lr_scheduler" in state_dict + ), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + scheduler.load_state_dict(state_dict["lr_scheduler"]) + self.print("Loaded LR scheduler state dict.") + + # Load RNG state, if present. 
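+        # Restoring RNG state lets random ops (e.g. dropout) continue deterministically
+        # from where the checkpoint left off.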
+ if "rng" in state_dict: + self.load_rng_state(state_dict["rng"]) + + return ckpt_dir, {} + + def offload_to_cpu(self): + self.all_buffer_sizes = offload_megatron_grads_to_cpu(self.actor_module) + offload_megatron_model_to_cpu(self.actor_module) + offload_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def backload_to_gpu(self): + load_megatron_grads_to_gpu(self.actor_module) + load_megatron_model_to_gpu(self.actor_module) + load_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # model init and bridge from huggingface methods: + def init_configs( + self, + model_path, + megatron_config, + transformer_config, + bf16=True, + flash_attn=True, + ): + """ + Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer + """ + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend + transformer_config = asdict(transformer_config) + transformer_config["attention_backend"] = "flash" if flash_attn else "fused" + + bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) + provider = bridge.to_megatron_provider() + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = ( + megatron_config.pipeline_model_parallel_size + ) + provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 + provider.context_parallel_size = megatron_config.context_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = ( + megatron_config.expert_tensor_parallel_size + ) + provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 + provider.attention_backend = "flash" if flash_attn else "fused" + provider.variable_seq_lengths = True + provider.masked_softmax_fusion = True + provider.moe_token_dispatcher_type = "alltoall" + + for k, v in transformer_config.items(): + setattr(provider, k, v) + provider.finalize() + + self.provider = provider + self.bridge = bridge + self.tokenizer = tokenizer + + def make_megatron_module( + self, + wrap_with_ddp: bool = True, + ddp_config: Optional[Dict[str, Any]] = None, + bf16: bool = True, + ) -> List[nn.Module]: + """ + Creates a megatron GPTModel (optionally DDP wrapped) using the bridge. + """ + from megatron.core.distributed.distributed_data_parallel_config import ( + DistributedDataParallelConfig, + ) + + default_ddp_config = DistributedDataParallelConfig() + if wrap_with_ddp: + default_ddp_config.use_distributed_optimizer = True + if ddp_config is not None: + for k, v in ddp_config.items(): + setattr(default_ddp_config, k, v) + model = self.provider.provide_distributed_model( + ddp_config=default_ddp_config, wrap_with_ddp=wrap_with_ddp, bf16=bf16 + ) + return model + + def init_model(self, model_path, num_training_steps: int = 1e9): + """ + Initialize the model, optimizer, and scheduler for the policy worker. 
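+
+        Typically called through `MegatronActorGroup.async_init_model(model_path)` on
+        every actor, after `init_worker_process_group` has set up the distributed state.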
+ """ + # Store model path for potential recovery + self._model_path = model_path + self._num_training_steps = num_training_steps + + # initialize the bridge and provider objects + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # wrap with DDP for training + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + if self._local_rank == 0 and not os.path.exists( + model_path + ): # if not local path, try downloading model weights from huggingface + snapshot_download(model_path) # will be no-op if already downloaded + torch.distributed.barrier() + + # create optimizer + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + self._normalize_mini_batch_size() + + # create scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # create worker model + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # NOTE: Set Megatron dist checkpoint async backend to persistent to avoid `os.fork()`-ing + # short-lived background workers, which does not work well with Ray. + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + + +class MegatronActorGroup: + """ + A group of distributed megatron actors + Functions start with 'async' should return list of object refs + + Args: + cfg: config object for workers + num_nodes (int): Number of nodes for this actor group. + num_gpus_per_node (int): Number of gpus for this actor group. + pg (PlacementGroup, optional): Placement group to schedule actor on. + If none, create new placement group automatically. Defaults to None. + num_gpus_per_actor (float, optional): Number of gpus allocated for each actor. + If < 1.0, multiple models can share same gpu. Defaults to 1. + """ + + def __init__( + self, + cfg, + num_nodes, + num_gpus_per_node, + pg: PlacementGroup, + bundle_indices: List[int], + num_gpus_per_actor: float = 1.0, + resources: Optional[Dict[str, float]] = None, + num_resources_per_node: Optional[int] = None, + ) -> None: + self.cfg = cfg + self._num_nodes = num_nodes + self._num_gpus_per_node = num_gpus_per_node + + # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html + self._resources = resources + self._num_resources_per_node = num_resources_per_node + + self._initiate_actors(pg, num_gpus_per_actor, bundle_indices) + + def _initiate_actors( + self, + pg: Optional[PlacementGroup], + num_gpus_per_actor: float, + bundle_indices: List[int], + ): + """Initialize Ray actors in the worker group. + + Args: + pg: The placement group for the worker group + num_gpus_per_actor: The number of gpus to allocate per actor. 
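+            bundle_indices: Placement group bundle indices used to pin actors to GPUs,
+                in rank order (rank i is scheduled on bundle_indices[i]).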
+ """ + world_size = self._num_nodes * self._num_gpus_per_node + assert pg is not None, "placement group must be provided to MegatronActorGroup" + pg_data = placement_group_table(pg) + assert ( + len(pg_data["bundles"]) >= world_size + ), "the number of bundles in the shared placement group must be greater than or equal to the world size" + + # place master actor on the + master_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[0], + ), + ).remote( + world_size=world_size, + rank=0, + local_rank=0, + master_addr=None, + master_port=None, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + + self._actor_handlers = [master_actor] + # Create worker actors + if world_size > 1: + master_addr, master_port = ray.get( + master_actor.get_master_addr_port.remote() + ) + for rank in range(1, world_size): + local_rank = rank % self._num_gpus_per_node + + worker_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[rank], + ), + ).remote( + world_size=world_size, + rank=rank, + local_rank=local_rank, + master_addr=master_addr, + master_port=master_port, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + self._actor_handlers.append(worker_actor) + + def initiate_worker_process_group(self): + # Initialize process group + logger.info("Initializing process group for RayActorGroup") + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + logger.info("Initialized process group for RayActorGroup") + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) + + def async_init_model( + self, + *args, + **kwargs, + ) -> List[ObjectRef]: + """Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers. + + Returns: + A list of ray object refs. + """ + return [ + actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers + ] + + def async_run_ray_method( + self, dispatch_type: str, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors using specified dispatch type asynchronously. 
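+        Arguments are validated for the chosen dispatch type first, so "mesh" expects a
+        single `TrainingInputBatch` while "pass_through" forwards arguments unchanged.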
+ + Args: + dispatch_type: Type of dispatch to use ("mesh" or "pass_through") + method_name: Name of the method to call on actors + *args: Positional arguments to pass to the method + **kwargs: Keyword arguments to pass to the method + + Returns: + List of object references + """ + dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type) + # validate the dispatch args to be sent to `.dispatch` + args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs) + + # Dispatch the method call + object_refs = dispatch_class.dispatch( + self.actor_infos, method_name, *args, **kwargs + ) + return object_refs + + def async_run_method_no_dispatch( + self, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors without dispatching.""" + return [ + getattr(handle, method_name).remote(*args, **kwargs) + for handle in self._actor_handlers + ] + + def _check_actor_alive(self, actor_handle) -> bool: + """Check if an actor is still alive by attempting to call a simple method.""" + try: + # Try to get a simple attribute or call a simple method with timeout + ray.get(actor_handle.get_mesh_rank.remote(), timeout=10) + return True + except Exception: + return False + + def recover_from_failure( + self, backup_actor_group: Optional["MegatronActorGroup"] = None + ): + """Recover from actor failures by removing dead actors and re-initializing process group.""" + logger.info("Starting recovery from actor failure...") + + # Filter out dead actors - both actor_infos and actor_handlers should be in sync + alive_actor_handlers = [] + num_dead_actors = 0 + dead_actor_ranks = [] + + for i, (actor_info, actor_handle) in enumerate( + zip(self.actor_infos, self._actor_handlers) + ): + if self._check_actor_alive(actor_info.handle): + alive_actor_handlers.append(actor_handle) + else: + logger.warning(f"Actor {i} is dead, removing from group") + num_dead_actors += 1 + dead_actor_ranks.append(i) + + if len(alive_actor_handlers) == 0: + raise RuntimeError("All actors are dead, cannot recover") + + if len(alive_actor_handlers) == len(self._actor_handlers): + logger.info("All actors are alive, no recovery needed") + return + + logger.info( + f"Recovering with {len(alive_actor_handlers)}/{len(self._actor_handlers)} actors" + ) + + self._actor_handlers = alive_actor_handlers + + # Destroy existing process groups on alive actors first + logger.info("Destroying old process groups...") + try: + ray.get( + [ + actor.destroy_worker_process_group.remote() + for actor in self._actor_handlers + ] + ) + except Exception as e: + logger.warning( + f"Some errors during process group destruction (may be expected): {e}" + ) + + # if backup actor group is provided, we pop idle actors from the backup actor group and insert them into the current actor group + if backup_actor_group is not None: + logger.info( + f"Popping {num_dead_actors} idle actors from backup actor group" + ) + idle_actor_handles = [ + backup_actor_group._actor_handlers.pop() for _ in range(num_dead_actors) + ] + # let's assume for now that the dead actors are contiguous in the actor group, so we insert the idle actors at the rank of the first dead actor + rank_to_insert = min(dead_actor_ranks) + logger.info(f"Inserting idle actors at rank {rank_to_insert}") + self._actor_handlers = ( + self._actor_handlers[:rank_to_insert] + + idle_actor_handles + + self._actor_handlers[rank_to_insert:] + ) + + # Re-initialize process group with remaining actors + # Update world_size and ranks to match the number of alive actors + new_world_size = 
len(self._actor_handlers) + + # Update world_size and reassign ranks sequentially (0, 1, 2, ...) + logger.info(f"Updating world_size to {new_world_size} and reassigning ranks...") + update_tasks = [] + for new_rank, actor in enumerate(self._actor_handlers): + update_tasks.append(actor.update_world_size.remote(new_world_size)) + update_tasks.append(actor.update_rank.remote(new_rank)) + ray.get(update_tasks) + + # get master address and a new free port for the new process group + master_addr, _ = ray.get(self._actor_handlers[0].get_master_addr_port.remote()) + master_port = ray.get(self._actor_handlers[0]._get_free_port.remote()) + logger.info(f"Using master_addr={master_addr}, master_port={master_port}") + + # Update master address/port in all actors + ray.get( + [ + actor.update_master_addr_port.remote(master_addr, master_port) + for actor in self._actor_handlers + ] + ) + + # Re-initialize process groups with new world_size and ranks + logger.info( + f"Re-initializing process group with world_size={new_world_size}..." + ) + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + + # Re-initialize model and optimizer with the new process group + # This is critical because they were created with the old process group + logger.info("Re-initializing model and optimizer with new process group...") + ray.get( + [ + actor.reinit_model_after_recovery.remote() + for actor in self._actor_handlers + ] + ) + + # Update actor_infos with new mesh ranks + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Recovery complete. New mesh ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) diff --git a/megatron_ray_fault_tolerant/megatron_model_utils.py b/megatron_ray_fault_tolerant/megatron_model_utils.py new file mode 100644 index 0000000..bc5be4a --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_utils.py @@ -0,0 +1,442 @@ +# Utils ported from NeMo-Aligner by way of NeMo-RL +# https://github.com/NVIDIA-NeMo/RL/blob/9301d36cbf847212430b84a27cfe6990f773b7cf/nemo_rl/distributed/model_utils.py#L4 +# The original copyright is reproduced below: + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + + +@torch.no_grad() +def _compute_distributed_log_softmax( + vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup +) -> torch.Tensor: + """Compute a stable distributed log softmax across tensor parallel workers. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_length, vocab_size//TP] + where TP is the tensor parallel size. + group (torch.distributed.ProcessGroup): Process group for the all-reduce operations. 
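+            The max and the sum of exponentials are all-reduced over this group, so the
+            result is normalized over the full vocabulary rather than the local shard.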
+ + Returns: + torch.Tensor: Log softmax output with the same shape as input, but values represent + log probabilities normalized across the full vocabulary dimension. + """ + logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True) + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max + + sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float() + + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype) + + +class DistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group) + softmax_output = log_probs.exp() + + log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) + log_probs[target_mask] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(softmax_output, target_mask, masked_target) + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + softmax, target_mask, masked_target = ctx.saved_tensors + + if softmax.ndim == 3: + B, S, V = softmax.shape + + # skip `torch.nn.functional.one_hot` + row = ( + torch.arange(B, device=softmax.device) + .view(-1, 1) + .expand(-1, S) + .reshape(-1) + ) + col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) + flat_idx = (row * S + col) * V + + flat_chosen = flat_idx.masked_select( + ~target_mask.reshape(-1) + ) + masked_target.masked_select(~target_mask) + + # `neg` is zero-copy + grad_input = softmax.neg() + grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) + + grad_output_selected = grad_output.masked_select(~target_mask) + grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) + else: + V = softmax.size(-1) + is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot( + masked_target, num_classes=V + ) + grad_input = is_chosen.float().sub_(softmax) + grad_input.mul_(grad_output.unsqueeze(-1)) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +class 
ChunkedDistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + The log probabilities computation is chunked in the sequence dimension + to mitigate GPU OOM (especially during backward pass). + In addition, logits casting from float16 or bfloat16 -> float32 is performed + inside the chunk loop to avoid materializing a whole float32 logits tensor. + + Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_log_probs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + + log_probs = torch.gather( + log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) + ).squeeze(-1) + log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + all_log_probs.append(log_probs) + + log_probs = torch.cat(all_log_probs, dim=1) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(vocab_parallel_logits, target_mask, masked_target) + ctx.chunk_size = chunk_size + ctx.tp_group = tp_group + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + vocab_parallel_logits, target_mask, masked_target = ctx.saved_tensors + chunk_size = ctx.chunk_size + tp_group = ctx.tp_group + + partition_vocab_size = int(vocab_parallel_logits.shape[-1]) + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + + all_grad_input = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + softmax_output = softmax_output.exp() + + # 1 if it's the chosen log prob, 0 otherwise + is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( + -1 + ) * torch.nn.functional.one_hot( + masked_target[:, chunk_start:chunk_end], + num_classes=partition_vocab_size, + ) + + grad_input = is_chosen.float().sub_(softmax_output) + + grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) + + all_grad_input.append(grad_input) + + 
grad_input = torch.cat(all_grad_input, dim=1) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +def from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, +) -> torch.Tensor: + """Get log probabilities from TP+CP sharded vocab logits. + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] + where TP is the tensor parallel size. + target (torch.Tensor): Target token indices with shape [batch_size, seq_len]. + NOTE: Must be the unmodified targets as this function will shift them internally. + vocab_start_index (int): Starting vocabulary index for this worker's partition. + vocab_end_index (int): Ending vocabulary index for this worker's partition. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. + inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + + Returns: + torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. + The sequence dimension is reduced by 1 due to the target shifting. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + """ + target = target.roll(shifts=-1, dims=-1) + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + pad_len = 0 + # if cp_size > 1: + # Pad the targets to local size * cp_size + pad_len = vocab_parallel_logits.shape[1] * cp_size - target.shape[1] + if pad_len > 0: + target = torch.nn.functional.pad(target, (0, pad_len), value=0) + + # Shard the targets by context parallelism + cp_rank = torch.distributed.get_rank(cp_group) + target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) + + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() + + if cp_size > 1: + # we need to gather the logits by context parallelism + logprobs = allgather_cp_sharded_tensor( + logprobs, cp_group, seq_dim=1 + ) # , unpadded_seqlen=target.shape[1]) + + if pad_len > 0: + logprobs = logprobs[:, :-pad_len] + + return logprobs[:, :-1] + + +def _get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1, +) -> torch.Tensor: + """Get tokens on this context parallelism rank. + + Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. 
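+
+    For example, with `cp_size == 2` the padded sequence is split into four
+    equal chunks `[c0, c1, c2, c3]`: CP rank 0 keeps `(c0, c3)` and CP rank 1
+    keeps `(c1, c2)`, so every rank sees a balanced mix of early and late
+    tokens under causal attention.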
+ + Args: + input_ids: Input token IDs [seq_length, ] + cp_rank: Context parallelism rank + cp_size: Context parallelism size + + Returns: + Tokens on this context parallelism rank [1, seq_length // cp_size] + """ + if cp_size == 1: + return input_ids + + # load balance for causal attention + shard_size = input_ids.shape[seq_dim] // (cp_size * 2) + shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1) + + # Create slices for each dimension + slices = [slice(None)] * input_ids.dim() + ids_chunks = [] + + for ind in shard_inds: + slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) + ids_chunks.append(input_ids[slices]) + + ids = torch.cat(ids_chunks, dim=seq_dim) + return ids + + +def allgather_cp_sharded_tensor( + tensor, cp_group, seq_dim=1 +): # , unpadded_seqlen=None): + return AllGatherCPTensor.apply(tensor, cp_group, seq_dim) # , unpadded_seqlen) + + +class AllGatherCPTensor(torch.autograd.Function): + def forward( + ctx, tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1 + ): # , unpadded_seqlen: Optional[int] = None): + cp_size = torch.distributed.get_world_size(cp_group) + cp_rank_chunks = [] + for _ in range(cp_size): + cp_rank_chunks.append(torch.empty_like(tensor)) + + torch.distributed.all_gather( + tensor_list=cp_rank_chunks, tensor=tensor, group=cp_group + ) + + # undo the CP load balancing chunking + tensor_chunks = [] + for logit_chunk in cp_rank_chunks: + tensor_chunks.extend(torch.chunk(logit_chunk, chunks=2, dim=seq_dim)) + + chunk_indices = [] + for cp_rank in range(cp_size): + chunk_indices.append(cp_rank) + chunk_indices.append(2 * cp_size - cp_rank - 1) + + chunks_and_indices = list(zip(tensor_chunks, chunk_indices)) + chunks_and_indices = sorted(chunks_and_indices, key=lambda tup: tup[1]) + ret_tensor = [chunk for chunk, _ in chunks_and_indices] + ret_tensor = torch.cat(ret_tensor, dim=seq_dim) + + ctx.seq_dim = seq_dim + ctx.cp_group = cp_group + # ctx.unpadded_seqlen = unpadded_seqlen + + return ret_tensor + + def backward(ctx, grad_output): + cp_size = torch.distributed.get_world_size(ctx.cp_group) + cp_rank = torch.distributed.get_rank(ctx.cp_group) + torch.distributed.all_reduce(grad_output, group=ctx.cp_group) + + # chunk the seqdim in 2*cp chunks, and select with a CP load balanced indexing + seq_dim = ctx.seq_dim + # if ctx.unpadded_seqlen is not None: + # # Zero out grad_output along the seq_dim after unpadded_seqlen + # slicer = [slice(None)] * grad_output.dim() + # slicer[seq_dim] = slice(ctx.unpadded_seqlen, None) + # grad_output[tuple(slicer)] = 0 + + grad_output = grad_output.view( + *grad_output.shape[0:seq_dim], + 2 * cp_size, + grad_output.shape[seq_dim] // (2 * cp_size), + *grad_output.shape[(seq_dim + 1) :], + ) + + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + + grad_input = grad_output.index_select(seq_dim, index) + grad_input = grad_input.view( + *grad_input.shape[0:seq_dim], -1, *grad_input.shape[(seq_dim + 2) :] + ) + + return grad_input, None, None # , None diff --git a/megatron_ray_fault_tolerant/megatron_model_wrapper.py b/megatron_ray_fault_tolerant/megatron_model_wrapper.py new file mode 100644 index 0000000..07e885d --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_wrapper.py @@ -0,0 +1,171 @@ +from typing import Optional, List +from functools import partial +import torch +import torch.nn as nn + +from megatron.core.pipeline_parallel import get_forward_backward_func +import megatron.core.parallel_state as mpu +from 
megatron.core.distributed import finalize_model_grads + +from megatron_model_utils import from_parallel_logits_to_logprobs +from megatron_utils import ( + get_model_config, + make_batch_generator, + preprocess_packed_seqs, + postprocess_packed_seqs, +) +from utils import ppo_policy_loss + + +class MegatronModelWrapper: + def __init__( + self, + config, + actor_module: List[nn.Module], + actor_optimizer: Optional[torch.optim.Optimizer] = None, + ): + self.cfg = config + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + + config = get_model_config(self.actor_module[0]) + # This is set to None by default: https://github.com/NVIDIA/Megatron-LM/blob/07b22a05136a3cb08ece05f7de38cf6aeeb165fb/megatron/core/model_parallel_config.py#L95 + # use the build in finalize_model_grads function to all reduce gradients across parallelism dimensions + config.finalize_model_grads_func = finalize_model_grads + + def train(self): + [module.train() for module in self.actor_module] + + def eval(self): + [module.eval() for module in self.actor_module] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward_backward_mini_batch( + self, + micro_batches: List[dict], + seq_len: int, + micro_batch_size: int, + temperature: float = 1.0, + ) -> List[dict]: + """ + Run forward-backward over a full mini-batch consisting of multiple micro-batches. + + Args: + micro_batches: A list of micro-batch dicts. Each dict must contain keys: + "sequences", "attention_mask", "position_ids", "num_actions", + "old_action_log_probs", "base_action_log_probs", "advantages", + "loss_mask". + seq_len: Sequence length (tokens) per sample (assumed same across micros after padding). + micro_batch_size: Micro-batch size per forward pass. + temperature: Optional temperature for logits scaling. + + Returns: + List[dict]: one metrics dict per micro-batch in order. 
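+
+        Example (illustrative sketch; assumes `wrapper` is an initialized
+        `MegatronModelWrapper` and `mb0`/`mb1` are micro-batch dicts with the
+        keys listed above):
+
+            metrics = wrapper.forward_backward_mini_batch(
+                micro_batches=[mb0, mb1],
+                seq_len=1024,
+                micro_batch_size=2,
+                temperature=1.0,
+            )
+            # each entry exposes metrics[i]["policy_loss"] and
+            # metrics[i]["ppo_clip_ratio"]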
+ """ + forward_backward_func = get_forward_backward_func() + + def loss_func(logits, data): + sequences = data["sequences"] + num_actions = data["num_actions"] + old_action_log_probs = data["old_action_log_probs"] + advantages = data["advantages"] + loss_mask = data["loss_mask"] + + tp_grp = mpu.get_tensor_model_parallel_group() + tp_rank = mpu.get_tensor_model_parallel_rank() + + # temperature normalization + if temperature != 1.0: + logits.div_(temperature) + + token_logprobs = from_parallel_logits_to_logprobs( + logits, + sequences, + vocab_start_index=tp_rank * logits.shape[-1], + vocab_end_index=(tp_rank + 1) * logits.shape[-1], + tp_group=tp_grp, + inference_only=False, + cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` + chunk_size=None, + ) + + action_log_probs = token_logprobs[:, -num_actions:] + + # policy loss should be calculated based on the selected token logprobs + policy_loss, clip_ratio = ppo_policy_loss( + action_log_probs, + old_action_log_probs, + advantages, + config=self.cfg, + loss_mask=loss_mask, + ) + + # no kl loss or entropy loss + loss = policy_loss + + metrics = { + "policy_loss": policy_loss.detach().item(), + "ppo_clip_ratio": clip_ratio, + } + return loss, metrics + + def forward_step(batch_iter, model): + batch = next(batch_iter) + + sequences = batch["sequences"] + attention_mask = batch["attention_mask"].to(bool) + + new_sequences, packed_seq_params = preprocess_packed_seqs( + sequences, + attention_mask, + pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True), + ) + new_attention_mask = None + new_position_ids = None + + outputs = model( + new_sequences, + new_position_ids, + new_attention_mask, + packed_seq_params=packed_seq_params, + ) + + outputs = postprocess_packed_seqs( + outputs, + packed_seq_params, + attention_mask, + micro_batch_size, + seq_len, + post_process=mpu.is_pipeline_last_stage(ignore_virtual=True), + ) + + return outputs, partial(loss_func, data=batch) + + # batch should be a list of micro-batches + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.actor_module) + ) + + metrics_list = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=len(micro_batches), + seq_length=seq_len, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + # broadcast metrics to all pp ranks + if not mpu.is_pipeline_last_stage(ignore_virtual=True): + metrics_list = [None] * len(micro_batches) + with torch.no_grad(): + torch.distributed.broadcast_object_list( + metrics_list, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + + return metrics_list diff --git a/megatron_ray_fault_tolerant/megatron_utils.py b/megatron_ray_fault_tolerant/megatron_utils.py new file mode 100644 index 0000000..4a0f015 --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_utils.py @@ -0,0 +1,465 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron_utils.py#L4 +# https://github.com/volcengine/verl/blob/dfa3933ac44b545fca1f6a8519fd07394a2cde1c/verl/models/mcore/util.py +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import gc +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.transformer.module import Float16Module +from megatron.core.optimizer import ChainedOptimizer +from megatron.core import parallel_state as mpu +from megatron.core.utils import get_attr_wrapped_model +from megatron.core.packed_seq_params import PackedSeqParams + +ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def make_batch_generator(batches, vpp_size): + """ + Creates a batch generator suitable for Megatron pipeline parallelism, + handling virtual pipeline parallelism (VPP). + + If VPP is used (vpp_size > 1), it duplicates the batch iterator for each + virtual pipeline stage. Otherwise, it returns a single iterator. + + Args: + batches: An iterable (e.g., list) of micro-batches. + vpp_size (int): The virtual pipeline model parallel size. + + Returns: + An iterator or a list of iterators over the micro-batches. + """ + if vpp_size > 1: + # has vpp + batch_generator = [batches] * vpp_size # number of vpp chunks + batch_generator = [iter(b) for b in batch_generator] + else: + # no vpp + batch_generator = iter(batches) + return batch_generator + + +@torch.no_grad() +def offload_megatron_grads_to_cpu(models): + all_buffer_sizes = [] + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + buffer_sizes = [] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.grad_data.storage().size() > 0: + buffer_sizes.append(buffer.grad_data.storage().size()) + buffer.grad_data.storage().resize_(0) + all_buffer_sizes.append(buffer_sizes) + else: + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + return all_buffer_sizes + + +@torch.no_grad() +def load_megatron_grads_to_gpu(models, buffer_sizes): + for i, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for j, buffers in enumerate(model_chunk_all_buffers): + for buffer in buffers: + buffer.grad_data.storage().resize_(buffer_sizes[i][j]) + buffer.grad_data.zero_() + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + 
model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = ( + buffer.param_data.data.cpu().pin_memory() + ) + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert ( + buffer.param_data_size + == buffer.param_data.cpu_data.storage().size() + ) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_model_to_gpu(models): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_( + buffer.param_data.cpu_data, non_blocking=True + ) + else: + # we need this for ref module + device_id = torch.cuda.current_device() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to("cpu", non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. 
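+
+    Note:
+        This is the inverse of `offload_megatron_copy_params` and only moves the
+        fp32 master weights (`shard_fp32_from_float16_groups`); the Adam state
+        (`exp_avg`/`exp_avg_sq`) is moved separately by `load_megatron_optimizer`.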
+ """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = torch.cuda.current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, "_move_new_state_to_right_device"): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to( + torch.cuda.current_device(), non_blocking=True + ) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros( + batch_size + 1, dtype=torch.int32, device=input_ids.device + ) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = ( + seqlens_in_batch.tolist() + ) # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = ( + seqlens_in_batch_padded.tolist() + ) # lengths after padding + cu_seqlens_padded_cpu: list[int] = ( + cu_seqlens_padded.tolist() + ) # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros( + shape, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[ + i, attention_mask[i] + ] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[ + start_idx + half_seqlen : start_idx + half_seqlen + remain_len + ] = d[remain_start:remain_end] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move 
the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + seq_lens_cpu: list[int] = ( + attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() + ) + + shape = [batch_size, seq_len] + list( + output.shape[2:] + ) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather( + output_list, output.detach(), group=mpu.get_context_parallel_group() + ) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s] + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[ + packed_start_idx + + half_seqlen : packed_start_idx + + s_len_padded_chunk + ], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[ + s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen + ] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) diff --git a/megatron_ray_fault_tolerant/optimizer.py b/megatron_ray_fault_tolerant/optimizer.py new file mode 100644 index 0000000..f243397 --- /dev/null +++ b/megatron_ray_fault_tolerant/optimizer.py @@ -0,0 +1,103 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron/optimizer.py#L4 +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
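+
+# Illustrative usage sketch of the helpers below (assumes `model_chunks` is the
+# list of Megatron model chunks and `train_cfg` is a dict-like training config):
+#
+#   optim_config = init_megatron_optim_config({"lr": 1e-6, "weight_decay": 0.01})
+#   optimizer = get_megatron_optimizer(model_chunks, optim_config)
+#   scheduler = get_megatron_optimizer_param_scheduler(optimizer, train_cfg)
+#   current_lr = get_megatron_last_lr(optimizer)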
+ +import torch +from megatron.core.optimizer import OptimizerConfig +from megatron.core.optimizer import ( + get_megatron_optimizer as get_megatron_optimizer_native, +) +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler + + +def init_megatron_optim_config(optim_config) -> OptimizerConfig: + optim_args = { + "optimizer": optim_config.get("optimizer", "adam"), + "lr": optim_config.get("lr", 1.0e-6), + "min_lr": optim_config.get("min_lr", 0.0), + "clip_grad": optim_config.get("max_grad_norm", 1.0), + "weight_decay": optim_config.get("weight_decay", 0.01), + "bf16": True, + "params_dtype": torch.bfloat16, + "use_distributed_optimizer": True, + } + + config = OptimizerConfig(**optim_args) + return config + + +def get_megatron_optimizer( + model, + config: OptimizerConfig, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, +): + # Base optimizer. + return get_megatron_optimizer_native( + config=config, + model_chunks=model, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult, + ) + + +def get_megatron_optimizer_param_scheduler( + optimizer, + config, + num_training_steps: int = 1e9, # default to a large number for constant lr/wd +): + """ + Get the optimizer parameter scheduler for Megatron. + """ + lr_warmup_steps = config.get("num_warmup_steps", 0) + if config.get("lr_decay_steps", None) is None: + lr_decay_steps = num_training_steps + if config.get("lr_warmup_steps_ratio", None) is not None and ( + config.get("lr_warmup_steps", None) is None + or config.get("lr_warmup_steps", 0) <= 0 + ): + lr_warmup_steps = int(config.get("lr_warmup_steps_ratio", 0.0) * lr_decay_steps) + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=config.get("lr_warmup_init", 0.0), + max_lr=config.get("lr", 1.0e-6), + min_lr=config.get("min_lr", 0.0), + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style="constant", + start_wd=config.get("weight_decay", 0.01), + end_wd=config.get("weight_decay", 0.01), + wd_incr_steps=num_training_steps, + wd_incr_style="constant", + use_checkpoint_opt_param_scheduler=False, + override_opt_param_scheduler=True, + wsd_decay_steps=None, + lr_wsd_decay_style="exponential", + ) + + return opt_param_scheduler + + +def get_megatron_last_lr(optimizer): + """ + Get the last learning rate from the optimizer parameter scheduler. 
+ """ + return optimizer.param_groups[0]["lr"] diff --git a/megatron_ray_fault_tolerant/pyproject.toml b/megatron_ray_fault_tolerant/pyproject.toml new file mode 100644 index 0000000..51be3c3 --- /dev/null +++ b/megatron_ray_fault_tolerant/pyproject.toml @@ -0,0 +1,98 @@ +[project] +name = "ray-ft" +version = "0.0.1" +description = "ray" +authors = [ + {name = "ray", email = "ray@gmail.com"} +] +license = {text = "MIT"} +readme = "README.md" +requires-python = "==3.12.*" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +dependencies = [ + "ninja", + "tensorboard", + "func_timeout", + "transformers>=4.51.0", + "torchdata", + "omegaconf", + "ray==2.51.0", + "peft", + "debugpy==1.8.0", + "hf_transfer", + "wandb", + "datasets==4.0.0", + "flash-attn", + "polars", + "loguru", + "jaxtyping", + "s3fs", + # Make sure to change the flash attention source (under tool.uv.sources) above to a compatible version (<= 2.7.4.post1) for TransformerEngine==2.5.0 + # https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl + # For single node: build transformer-engine separately first, and uncomment the transformer-engine library import below + # uv pip install "torch==2.7.1" + # uv pip install "nvidia-cudnn-cu12>=9.3" + # export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" + # export CPATH="$CUDNN_PATH/include:${CPATH:-}" + # export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}" + # uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose + # "transformer-engine[pytorch]==2.5.0", + "transformer-engine[pytorch]==2.7.0", + "flash-attn==2.7.4.post1", + "vllm==0.10.1.1", + "torch==2.7.1", + "flashinfer-python", + "torchvision", + "megatron-bridge==0.1.0rc4", + "megatron-core==0.14.0", +] + +[tool.uv] +required-version = ">=0.8.10" +no-build-isolation-package = [ + "transformer-engine-torch", + "transformer-engine", +] + +[tool.uv.extra-build-dependencies] +flash-attn = [{requirement = "torch", match-runtime = true}] +transformer-engine = [{ requirement = "torch", match-runtime = true }, "build_tools"] +transformer-engine-torch = [{ requirement = "torch", match-runtime = true }, "build_tools"] + +[tool.uv.extra-build-variables] +flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"} + +[tool.uv.sources] +torch = { index = "pytorch-cu128" } +torchvision = { index = "pytorch-cu128" } +# We use `flashinfer-jit-cache` to avoid slow JIT compilation on first run. 
+# Different inference engines may pin different compatible flashinfer versions, so we provide the option to pin different versions for vllm/sglang +flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" } +flashinfer-python = [ + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" }, + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" } +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[[tool.uv.index]] +name = "flashinfer-cu128" +url = "https://flashinfer.ai/whl/cu128" +explicit = true + +[tool.setuptools] +include-package-data = true + +[tool.pytest.ini_options] +addopts = "-v -s" +testpaths = [ + "tests", +] \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/run.sh b/megatron_ray_fault_tolerant/run.sh new file mode 100755 index 0000000..c9455a3 --- /dev/null +++ b/megatron_ray_fault_tolerant/run.sh @@ -0,0 +1 @@ +anyscale job submit -f job.yaml \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/training_batch.py b/megatron_ray_fault_tolerant/training_batch.py new file mode 100644 index 0000000..eacdbe6 --- /dev/null +++ b/megatron_ray_fault_tolerant/training_batch.py @@ -0,0 +1,371 @@ +"""Defines interfaces for training data.""" + +from typing import TypedDict, Dict, Any, List, Optional, Generic, TypeVar +import torch +from jaxtyping import Float, Integer +import pickle +import io + +DictType = TypeVar("DictType") + + +# Class inspired by `TensorDict` but is much simpler. +class TensorBatch(dict, Generic[DictType]): + """Base class for training batches + + This defines a generic container for a batch of training data (inputs or outputs). + Consists of a dictionary of tensors along with some metadata. + """ + + metadata: Optional[Dict[str, Any]] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._batch_size = None + self._device = None + self._check_consistency() + + def select( + self, keys: List[str], metadata_keys: Optional[List[str]] = None + ) -> "TensorBatch[DictType]": + """Select a subset of the data batch. 
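+
+        For example, `batch.select(["sequences", "attention_mask"])` returns a
+        new batch holding only those two tensors; when `metadata_keys` is None
+        the full `metadata` dict is carried over unchanged.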
+ + Args: + keys: The keys to select + metadata_keys: The metadata keys to select + + Returns: + A new `TensorBatch` object with the selected keys and metadata + """ + selected_batch_data = {} + for key in keys: + selected_batch_data[key] = self[key] + selected_metadata = {} + if metadata_keys is None: + selected_metadata = self.metadata + else: + selected_metadata = {} + for key in metadata_keys: + selected_metadata[key] = self.metadata[key] + new_batch = self.__class__(selected_batch_data) + new_batch.metadata = selected_metadata + return new_batch + + def _check_consistency(self): + """Check consistency of all present fields""" + keys = list(self.keys()) + if len(keys) == 0: + return + + batch_size = len(self[keys[0]]) + self._batch_size = batch_size + for key in keys: + value = self[key] + if value is None: + continue + self._device = value.device if self._device is None else self._device + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + if len(value) != batch_size: + raise ValueError(f"Batch size mismatch in {key}") + if value.device != self._device: + raise ValueError( + f"Device mismatch in {key}. Expected {self._device}, got {value.device}" + ) + + def __getitem__(self, index) -> "TensorBatch[DictType]": + if isinstance(index, slice): + return self.slice(index.start, index.stop, index.step) + elif isinstance(index, int): + return self.slice(index, index + 1) + else: + return super().__getitem__(index) + + def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None: + if value is None: + super().__setitem__(key, value) + return + + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + + if ( + hasattr(self, "_batch_size") + and self._batch_size is not None + and len(value) != self._batch_size + ): + raise ValueError( + f"Batch size mismatch in {key}. Expected tensor to be of size {self._batch_size}, got {len(value)}." + ) + + super().__setitem__(key, value) + + if hasattr(self, "_batch_size") and self._batch_size is None: + self._batch_size = len(value) + + def to( + self, + device: torch.device = None, + dtype: torch.dtype = None, + *, + non_blocking: bool = False, + ) -> "TensorBatch": + """Move tensors to device and/or cast to dtype. 
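+
+        For example, `batch.to(torch.device("cuda"), torch.bfloat16)` casts and
+        moves every tensor in place and returns the same `TensorBatch`.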
+ + Args: + device: The device to move the tensors to + dtype: The dtype to cast the tensors to + non_blocking: Whether the operation should be non-blocking + """ + for key, value in self.items(): + if value is None: + continue + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.to(device, dtype, non_blocking=non_blocking) + return self + + def contiguous(self) -> "TensorBatch": + """Make the tensors contiguous""" + for key, value in self.items(): + if value is None: + continue + # some of these asserts are not needed, but it's kept for type safety + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.contiguous() + return self + + @property + def batch_size(self) -> int: + """Batch size for the tensors""" + return self._batch_size + + @property + def device(self) -> torch.device: + """Get the device for the tensors""" + return self._device + + def __getstate__(self): + """Serialize the `TensorBatch` object for pickle protocol""" + self.contiguous() + if self._device is not None: + assert self._device == torch.device( + "cpu" + ), "Tensors must be on CPU before serialization" + batch_dict = {} + for key, value in self.items(): + buffer = io.BytesIO() + torch.save(value, buffer) + batch_dict[key] = buffer.getvalue() + + return { + "batch_dict": batch_dict, + "batch_size": self._batch_size, + "device": self._device, + "metadata": self.metadata, + } + + def __setstate__(self, state): + """Deserialize the `TensorBatch` object and load it into memory""" + for key, value in state["batch_dict"].items(): + buffer = io.BytesIO(value) + self[key] = torch.load(buffer) + + self._batch_size = state["batch_size"] + self._device = state["device"] + self.metadata = state["metadata"] + self._check_consistency() + return self + + def repeat(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated. + + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def repeat_interleave(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated. 
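+
+        For example, with `repeats=2` the entries `[a, b]` are repeated
+        back-to-back as `[a, a, b, b]` per key, whereas `repeat` tiles the
+        whole batch as `[a, b, a, b]`.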
+ + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat_interleave(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]: + """Split into smaller chunks""" + chunks = [] + for i in range(0, self.batch_size, chunk_size): + chunk_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + chunk_data[key] = value[i : i + chunk_size] + else: + raise ValueError( + f"Unsupported type {type(value)} for key {key}" + ) + else: + # `None` values are not chunked + chunk_data[key] = value + chunk = self.__class__(chunk_data) + chunk.metadata = self.metadata + chunks.append(chunk) + return chunks + + def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]": + """Slice the data batch. + + Args: + start: The start index + end: The end index + step: The step size + + Returns: + A new `TensorBatch` object with the view of the specified slice. + """ + slice_obj = slice(start, end, step) + sliced_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + sliced_data[key] = value[slice_obj] + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not sliced + sliced_data[key] = value + sliced_batch = self.__class__(sliced_data) + sliced_batch.metadata = self.metadata + return sliced_batch + + def save(self, path: str): + """Save the data to a pickle file""" + with open(path, "wb") as f: + pickle.dump(self, f) + + def load(self, path: str): + """Load the data from a pickle file""" + with open(path, "rb") as f: + return pickle.load(f) + + @classmethod + def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]": + """Concatenate shards. + + Args: + shards: The list of `TensorBatch` objects to cat + + Returns: + A new `TensorBatch` object with the concatenated data + """ + cat_data = {} + assert len(shards) > 0, "Cannot cat an empty list of shards" + for key, value in shards[0].items(): + if value is not None: + if isinstance(value, torch.Tensor): + cat_data[key] = torch.cat([shard[key] for shard in shards]) + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not cat'd + cat_data[key] = value + metadata = shards[0].metadata + cat_batch = cls(cat_data) + cat_batch.metadata = metadata + return cat_batch + + def __len__(self) -> int: + """Length of the batch. + + Note that this is the same as the batch size rather than the number of keys in the batch. 
+ """ + return self._batch_size + + def __eq__(self, other: Any) -> bool: + """Check if two `TensorBatch` objects are equal""" + if not isinstance(other, TensorBatch): + return False + if self.metadata != other.metadata: + return False + if len(self) != len(other): + return False + if len(self.items()) != len(other.items()): + return False + for k, v in self.items(): + if k not in other or not torch.equal(v, other[k]): + return False + return True + + def __str__(self) -> str: + """String representation of the `TensorBatch` object""" + return f"TensorBatch(batch_size={self.batch_size}, device={self.device}, metadata={self.metadata}), items={self.items()}" + + def __repr__(self) -> str: + """String representation of the `TensorBatch` object""" + return self.__str__() + + +class TrainingInput(TypedDict, total=False): + """Schema for training input batch""" + + sequences: Integer[torch.Tensor, "batch_size seq_len"] + attention_mask: Integer[torch.Tensor, "batch_size seq_len"] + loss_mask: Integer[torch.Tensor, "batch_size seq_len"] + response_mask: Integer[torch.Tensor, "batch_size seq_len"] + action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + values: Optional[Float[torch.Tensor, "batch_size seq_len"]] + returns: Float[torch.Tensor, "batch_size seq_len"] + advantages: Float[torch.Tensor, "batch_size seq_len"] + kl: Float[torch.Tensor, "batch_size seq_len"] + rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] + rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + + +class TrainingInputBatch(TensorBatch[TrainingInput]): + """Training input data""" + + pass + + +class TrainingOutputBatch(TensorBatch[Dict[str, torch.Tensor]]): + """Training output data""" + + pass diff --git a/megatron_ray_fault_tolerant/utils.py b/megatron_ray_fault_tolerant/utils.py new file mode 100644 index 0000000..e07689d --- /dev/null +++ b/megatron_ray_fault_tolerant/utils.py @@ -0,0 +1,286 @@ +import ray +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +import torch +from typing import Any, Optional, Dict, List, Union, Tuple +from dataclasses import dataclass +from jaxtyping import Integer, Float +import math +from transformers import AutoTokenizer + + +from training_batch import TrainingInputBatch + +BasicType = Union[int, float, str, bool] + + +@ray.remote(num_gpus=1) +class InfoActor: + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + +def get_reordered_bundle_indices(pg: PlacementGroup): + """ + Get the reordered bundle indices for a placement group to ensure adjacent ranks are on the same node when possible + """ + pg_data = placement_group_table(pg) + num_bundles = len(pg_data["bundles"]) + bundle_to_node_ids = pg_data["bundles_to_node_id"] + + # use info actor to get the GPU id + info_actors = [] + for i in range(num_bundles): + info_actors.append( + InfoActor.options( + num_cpus=0.01, # set both num_cpus and num_gpus to be small values to enable assignment in colocated case + num_gpus=0.01, + resources=None, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=i, + ), + ).remote() + ) + + gpu_ids = ray.get([actor.get_gpu_id.remote() for actor in info_actors]) + for actor in info_actors: + ray.kill(actor) + + # original index, node_id, gpu_id + bundle_infos = [(i, bundle_to_node_ids[i], gpu_ids[i]) for i in range(num_bundles)] + pg_reordered_bundle_indices = [ + 
bundle_info[0] + for bundle_info in sorted(bundle_infos, key=lambda x: (x[1], x[2])) + ] # sort by node_id, then gpu_id + return pg_reordered_bundle_indices + + +def to(tensor: Union[torch.Tensor, List[torch.Tensor], BasicType], device): + if isinstance(tensor, list): + return [to(t, device) for t in tensor] + elif isinstance(tensor, torch.Tensor): + return tensor.to(device) + else: + return tensor + + +@dataclass +class Experience: + """Experience is a batch of data. + These data should have the the sequence length and number of actions. + Left padding for sequences is applied. + + Shapes of each tensor: + sequences: (B, S) + action_log_probs: (B, A) + base_action_log_probs: (B, A) + values: (B, A) + returns: (B, A) + advatanges: (B, A) + attention_mask: (B, S) + action_mask: (B, A) + kl: (B, A) + + "A" is the number of actions/ response length. + """ + + sequences: Integer[torch.Tensor, "batch seq_len"] + action_log_probs: Float[torch.Tensor, "batch response_len"] + base_action_log_probs: Optional[Float[torch.Tensor, "batch response_len"]] + values: Optional[Float[torch.Tensor, "batch response_len"]] + returns: Optional[Float[torch.Tensor, "batch response_len"]] + advantages: Optional[Float[torch.Tensor, "batch response_len"]] + attention_mask: Optional[Integer[torch.LongTensor, "batch seq_len"]] + loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] + action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] + rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + num_actions: int + info: Optional[dict] + kl: Optional[Float[torch.Tensor, "batch response_len"]] = None + metadata: Optional[Dict[str, Any]] = None + + @torch.no_grad() + def to_device(self, device: torch.device) -> None: + self.sequences = to(self.sequences, device) + self.action_log_probs = to(self.action_log_probs, device) + if self.base_action_log_probs is not None: + self.base_action_log_probs = to(self.base_action_log_probs, device) + if self.values is not None: + self.values = to(self.values, device) + if self.returns is not None: + self.returns = to(self.returns, device) + if self.advantages is not None: + self.advantages = to(self.advantages, device) + if self.attention_mask is not None: + self.attention_mask = to(self.attention_mask, device) + if self.loss_mask is not None: + self.loss_mask = to(self.loss_mask, device) + if self.action_mask is not None: + self.action_mask = to(self.action_mask, device) + if self.rollout_logprobs is not None: + self.rollout_logprobs = to(self.rollout_logprobs, device) + + +class BatchIterator: + """A simple iterator to yield micro batches of data from the training batch.""" + + def __init__( + self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False + ): + self.data = data + self.sample_batch_size = sample_batch_size + self.total_batch_size = data.batch_size + self.drop_last = drop_last + assert not drop_last, "drop_last is not supported yet" + num_micro_batches = self.total_batch_size / self.sample_batch_size + self.num_micro_batches = ( + int(num_micro_batches) if drop_last else math.ceil(num_micro_batches) + ) + # TODO: switch to tensordict.map_iter if possible + self._chunks = self.data.chunk(self.sample_batch_size) + self._iter = iter(self._chunks) + + def __len__(self): + return self.num_micro_batches + + def __iter__(self): + return self + + def __next__(self) -> Experience: + try: + batch = next(self._iter) + exp = self.batch_to_experience(batch) + return exp + except StopIteration: + self._iter = 
iter(self._chunks) + raise StopIteration + + @staticmethod + def batch_to_experience(batch: TrainingInputBatch): + exp = Experience( + sequences=batch["sequences"], + action_log_probs=batch["action_log_probs"], + base_action_log_probs=batch["base_action_log_probs"], + values=batch["values"], + returns=batch["returns"], + advantages=batch["advantages"], + attention_mask=batch["attention_mask"], + loss_mask=batch["loss_mask"], + action_mask=batch["response_mask"], + num_actions=batch.metadata["response_length"], # int + rollout_logprobs=( + batch["rollout_logprobs"] if "rollout_logprobs" in batch else None + ), + # additional info + # can be used to log metrics etc for micro-batches in the worker + info={}, + # propagate metadata as is + metadata=batch.metadata, + ) + return exp + + +def masked_mean( + tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None +) -> torch.Tensor: + if mask is None: + return tensor.mean(axis=dim) + return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim).clamp(min=1.0) + + +def _safe_exp_delta( + delta: torch.Tensor, clip: float = 20.0, out_dtype=None +) -> torch.Tensor: + """ + Clamp the delta before exponentiating to avoid potential overflow. + """ + y = torch.clamp(delta.to(torch.float32), -clip, clip).exp() + return y.to(out_dtype or delta.dtype) + + +def ppo_policy_loss( + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + config, + loss_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, float]: + """Compute dual clip PPO policy loss.""" + ratio = _safe_exp_delta( + log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype + ) + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages + loss = -torch.min(surr1, surr2) + clip_ratio = ( + masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() + ) + clip_pg_losses1 = loss + pg_losses3 = -advantages * config.clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + + loss = loss = masked_mean(loss, loss_mask) + return loss, clip_ratio + + +def get_test_training_batch(model_name, batch_size=4) -> TrainingInputBatch: + """ + Returns a test training batch with padded seqs and attention masks + + Gives a batch of 4 sequences with variable amounts of left padding, and variable response lengths/amounts of right padding + Attention masks are 1 for non-padding tokens, 0 for padding tokens + The rest of the fields are filled with dummy data + """ + assert batch_size % 4 == 0, "batch size must be divisible by 4" + num_repeats = batch_size // 4 + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + sentences = [ + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.", + "<|im_start|>user\nThe selling price of a bicycle that had sold $220 last year was increased by 15", + "What is the new price? 
Let's think step by step and output the final answer after `####`.<|im_end|>\n", + "<|im_start|>assistant\nTo find the new price of the bicycle after the increase,", + ] * num_repeats + + sequences = [tokenizer.encode(sentence) for sentence in sentences] + attention_masks = [[1] * len(seq) for seq in sequences] + num_actions = 10 + # max seq len 1 longer than the longest sequence so we always have some padding + max_seq_length = max([len(seq) for seq in sequences]) + 7 + + pad_token_id = tokenizer.pad_token_id + pad_before = [4, 0, 1, 6] * num_repeats + pad_after = [ + max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences) + ] + + for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): + sequences[i] = ( + [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after + ) + attention_masks[i] = [0] * pad_before + attention_masks[i] + [0] * pad_after + + attention_masks = torch.tensor(attention_masks) + sequences = torch.tensor(sequences) + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_masks, + "action_log_probs": torch.tensor([[0.1] * num_actions] * batch_size), + "base_action_log_probs": torch.tensor([[0.2] * num_actions] * batch_size), + "rollout_logprobs": torch.tensor([[0.11] * num_actions] * batch_size), + "values": torch.tensor([[0.1] * num_actions] * batch_size), + "returns": torch.tensor([[0.1] * num_actions] * batch_size), + "advantages": torch.tensor([[0.5] * num_actions] * batch_size), + "loss_mask": torch.tensor([[1] * num_actions] * batch_size), + "response_mask": torch.tensor([[1] * num_actions] * batch_size), + } + ) + data.metadata = {"response_length": num_actions} + return data From 29295469997cc8ae14ba142c3f7f81bfa1fd3a44 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Tue, 18 Nov 2025 22:25:12 -0800 Subject: [PATCH 06/15] remove redundant branch --- image_processing/Dockerfile | 14 -- image_processing/README.md | 9 - image_processing/job.yaml | 62 ----- image_processing/process_images.py | 360 ----------------------------- image_processing/run.sh | 2 - 5 files changed, 447 deletions(-) delete mode 100644 image_processing/Dockerfile delete mode 100644 image_processing/README.md delete mode 100644 image_processing/job.yaml delete mode 100644 image_processing/process_images.py delete mode 100755 image_processing/run.sh diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile deleted file mode 100644 index d48065a..0000000 --- a/image_processing/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM anyscale/ray:2.51.1-slim-py312-cu128 - -# C compiler for Triton’s runtime build step (vLLM V1 engine) -# https://github.com/vllm-project/vllm/issues/2997 -RUN sudo apt-get update && \ - sudo apt-get install -y --no-install-recommends build-essential - -RUN curl -LsSf https://astral.sh/uv/install.sh | sh - -RUN uv pip install --system huggingface_hub boto3 - -RUN uv pip install --system vllm==0.11.0 - -RUN uv pip install --system transformers==4.57.1 \ No newline at end of file diff --git a/image_processing/README.md b/image_processing/README.md deleted file mode 100644 index 5ea8aca..0000000 --- a/image_processing/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Process images - -This example uses Ray Data to process the [ReLAION-2B](https://huggingface.co/datasets/laion/relaion2B-en-research-safe) image dataset, which consists of over 2 billion rows. Each row consists of an image URL along with various metadata include a caption and image dimensions. 
- -## Install the Anyscale CLI - -``` -anyscale job submit -f job.yaml --env HF_TOKEN=$HF_TOKEN -``` \ No newline at end of file diff --git a/image_processing/job.yaml b/image_processing/job.yaml deleted file mode 100644 index cdd32c0..0000000 --- a/image_processing/job.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. - -name: process-images - -# When empty, use the default image. This can be an Anyscale-provided base image -# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided -# that it meets certain specs), or you can build new images using the Anyscale -# image builder at https://console.anyscale-staging.com/v2/container-images. -# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 -containerfile: ./Dockerfile - -# When empty, Anyscale will auto-select the instance types. You can also specify -# minimum and maximum resources. -compute_config: - # OPTION 1: Auto-selection (current - works on Anyscale-hosted) - # Uses default disk sizes (~100GB). Cannot customize disk with auto-selection. - min_resources: - CPU: 0 - GPU: 0 - max_resources: - CPU: 520 - GPU: 128 - auto_select_worker_config: true - - # OPTION 2: Explicit config with custom disk (CUSTOMER-HOSTED ONLY) - # Uncomment below and comment out the auto-selection config above to use custom disk. - # NOTE: advanced_instance_config only works on customer-hosted AWS accounts. - # See DISK_SIZE_OPTIONS.md for details. - # - # head_node: - # instance_type: m5.2xlarge - # advanced_instance_config: - # BlockDeviceMappings: - # - DeviceName: /dev/sda1 - # Ebs: - # VolumeSize: 500 - # VolumeType: gp3 - # worker_nodes: - # - instance_type: m5.16xlarge - # min_nodes: 0 - # max_nodes: 100 - # advanced_instance_config: - # BlockDeviceMappings: - # - DeviceName: /dev/sda1 - # Ebs: - # VolumeSize: 500 - # VolumeType: gp3 - -# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that -# will be the working directory for the job. The files in the directory will be -# automatically uploaded to the job environment in Anyscale. -working_dir: . - -# When empty, this uses the default Anyscale Cloud in your organization. -cloud: - -# The script to run in your job. You can also do "uv run main.py" if you have a -# pyproject.toml file in your working_dir. -entrypoint: python process_images.py - -# If there is an error, do not retry. -max_retries: 0 \ No newline at end of file diff --git a/image_processing/process_images.py b/image_processing/process_images.py deleted file mode 100644 index 00d02d3..0000000 --- a/image_processing/process_images.py +++ /dev/null @@ -1,360 +0,0 @@ -import base64 -import concurrent.futures -import os -import ray -import requests -from huggingface_hub import HfFileSystem -from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor -from PIL import Image -from io import BytesIO -import pyarrow.fs as pafs -from requests.adapters import HTTPAdapter -import urllib3 -import logging -import warnings - -logging.getLogger("urllib3").setLevel(logging.ERROR) -logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) - - -# Disable SSL warnings since we're disabling verification for misconfigured image hosts -# Suppress urllib3 connection pool warnings (timeout, connection errors, etc.) 
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - -# ============================================================================ -# SCALABILITY CONFIGURATION FOR 2B+ IMAGES -# ============================================================================ -# num_images = 100 -num_model_replicas = 32 -tensor_parallelism = 1 -max_concurrent_downloads = 10 # Reduced to minimize memory spikes (was 10) - -from datetime import datetime, timezone - -timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") - -output_path = f"/mnt/shared_storage/process_images_output/{timestamp}" - - -def create_session(): - """ - Create a requests session for image downloads without automatic retries. - """ - session = requests.Session() - - adapter = HTTPAdapter( - pool_connections=50, - pool_maxsize=50, - # Keep connections alive longer - pool_block=False, - ) - session.mount("http://", adapter) - session.mount("https://", adapter) - return session - - -def fetch_image(url, session=None): - """ - Fetch image with validation and error handling. - - If the download or validation fails, the image is marked as invalid without retrying. - """ - # Validate URL format first - if not url or not isinstance(url, str): - return None, False, "Invalid URL: empty or not a string" - - # Parse URL to properly validate its structure - from urllib.parse import urlparse - - try: - parsed = urlparse(url) - - # Check if URL has a valid scheme (http or https) - if parsed.scheme not in ("http", "https"): - return ( - None, - False, - f"Invalid URL", - ) - - # Check if URL has a valid hostname/netloc - # netloc is the network location (domain/IP), e.g., "example.com" or "192.168.1.1" - if not parsed.netloc or len(parsed.netloc.strip()) < 3: - return None, False, f"Invalid URL: missing or invalid hostname: {url[:100]}" - - # Check if netloc contains at least a dot (for domain) or is localhost/IP - # This catches URLs like "https://albumart/image.jpg" which have no valid domain - if "." 
not in parsed.netloc and parsed.netloc not in ("localhost", "127.0.0.1"): - # Allow if it looks like an IPv6 address (contains colons) - if ":" not in parsed.netloc: - return ( - None, - False, - f"Invalid URL: malformed hostname (missing domain): {url[:100]}", - ) - - except Exception as e: - return None, False, f"Invalid URL: failed to parse: {str(e)[:100]}" - - # Create session if not provided (will be reused within batch) - if session is None: - session = create_session() - - try: - response = session.get( - url, - timeout=(10, 20), # (connect_timeout=30s, read_timeout=60s) - verify=False, # Disable SSL verification for broken certs - allow_redirects=True, # Follow redirects - stream=False, # Download entire response - ) - except Exception as e: - return None, False, f"Error: {str(e)[:10]}" - - if response.status_code == 200: - ctype = response.headers.get("Content-Type", "") - if ctype.startswith("image"): - image_bytes = response.content - try: - with warnings.catch_warnings(): - warnings.filterwarnings("error", category=UserWarning) - warnings.filterwarnings( - "error", category=Image.DecompressionBombWarning - ) - # First verify the image format - img = Image.open(BytesIO(image_bytes)) - img.verify() # This checks if file is broken - # After verify(), we need to reopen to actually load the image - img = Image.open(BytesIO(image_bytes)) - img.load() # Force full image loading to detect truncation - img.close() - except (OSError, IOError) as e: - # Catch truncated images and other IO errors - error_msg = str(e)[:100] - if "truncated" in error_msg.lower(): - return None, False, f"Truncated image: {error_msg}" - return None, False, f"Image IO error: {error_msg}" - except Exception as e: - return None, False, f"Image validation error: {str(e)[:100]}" - return image_bytes, True, None - return None, False, f"Content-Type is not an image: {ctype}" - - return None, False, f"Status code: {response.status_code}" - - -def is_jpeg_format(row): - """Memory-efficient JPEG format check without keeping Image object in memory.""" - try: - image_data = row.get("image_base64") - if image_data is None: - return False - if isinstance(image_data, str): - image_data = base64.b64decode(image_data) - with Image.open(BytesIO(image_data)) as img: - return img.format == "JPEG" - except: - return False - - -def resize_image(row): - """Resize image to 128x128 pixels and standardize RGB values.""" - try: - image_data = row.get("image_base64") - if image_data is None: - return row - - # Decode base64 string to bytes - if isinstance(image_data, str): - image_bytes = base64.b64decode(image_data) - else: - image_bytes = image_data - - # Open image, convert to RGB, resize, and save back - with Image.open(BytesIO(image_bytes)) as img: - # Convert to RGB mode to ensure consistent 3-channel format - # This handles CMYK, grayscale, RGBA, etc. 
- if img.mode != "RGB": - img = img.convert("RGB") - - # Resize to 128x128 using high-quality Lanczos resampling - resized_img = img.resize((128, 128), Image.Resampling.LANCZOS) - - # Save resized image to bytes - output_buffer = BytesIO() - resized_img.save(output_buffer, format="JPEG", quality=95) - resized_bytes = output_buffer.getvalue() - - # Encode back to base64 string - row["image_base64"] = base64.b64encode(resized_bytes).decode("ascii") - return row - except Exception as e: - # If resize fails, keep original image - return row - - -def fetch_images_batch_threaded(batch): - """Fetch images in parallel with increased concurrency for network throughput.""" - # Disable SSL warnings in each Ray worker process - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - logging.getLogger("urllib3").setLevel(logging.ERROR) - logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) - - # Create a single shared session across threads in this batch - # Note: requests.Session is thread-safe for reading - session = create_session() - - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_concurrent_downloads - ) as executor: - # Pass session to each fetch_image call - results = list( - executor.map(lambda url: fetch_image(url, session), batch["url"]) - ) - - batch["image_base64"] = [ - ( - base64.b64encode(result[0]).decode("ascii") - if (result[0] is not None and result[1]) - else None - ) - for result in results - ] - batch["success"] = [result[1] for result in results] - return batch - - -vision_processor_config = vLLMEngineProcessorConfig( - model_source="Qwen/Qwen2.5-VL-3B-Instruct", - engine_kwargs=dict( - tensor_parallel_size=tensor_parallelism, - pipeline_parallel_size=1, - max_model_len=32768, - enable_chunked_prefill=True, - max_num_batched_tokens=2048, - ), - # Override Ray's runtime env to include the Hugging Face token. Ray Data uses Ray under the hood to orchestrate the inference pipeline. 
- runtime_env=dict( - env_vars=dict( - VLLM_USE_V1="1", - VLLM_DISABLE_COMPILE_CACHE="1", - ), - ), - batch_size=8, # Reduced from 16 to lower memory usage - max_concurrent_batches=16, # Increased to saturate vLLM engine (8 * 16 = 128) - accelerator_type="A10G", - concurrency=num_model_replicas, - has_image=True, -) - - -def vision_preprocess(row: dict) -> dict: - # Keep image data as base64 string for Arrow serialization - # The vLLM engine will handle the conversion internally - image_data = row["image_base64"] - - image_data = f"data:image;base64,{image_data}" - return dict( - messages=[ - { - "role": "user", - "content": [ - { - "type": "image", - "image": image_data, - }, - ], - }, - ], - sampling_params=dict( - temperature=0.3, - max_tokens=150, - detokenize=False, - ), - ) - - -def vision_postprocess(row: dict) -> dict: - row.pop("image_base64") - return row - - -vision_processor = build_llm_processor( - vision_processor_config, - preprocess=vision_preprocess, - postprocess=vision_postprocess, -) - - -# Initialize Ray with S3 spilling configuration -# This ensures Ray can access S3 for object spilling on all workers -if not ray.is_initialized(): - ray.init() - -ray.data.DataContext.get_current().retried_io_errors.extend( - [ - # Network connectivity errors - "Temporary failure in name resolution", - "Name or service not known", - "Max retries exceeded with url", - "Failed to establish a new connection", - "Connection refused", - "Connection timed out", - "Read timed out", - "ConnectTimeoutError", - "connect timeout", - "HTTPSConnectionPool", - "Remote end closed connection", - "Connection broken", - # SSL/TLS errors - "SSLError", - "SSL: CERTIFICATE_VERIFY_FAILED", - "hostname mismatch", - "certificate verify failed", - # Rate limiting - "429 Client Error: Too Many Requests", - "We had to rate limit you", - ] -) -num_cpu = 512 -tasks_per_cpu = 1 -concurrency = num_cpu * tasks_per_cpu -ctx = ray.data.DataContext.get_current() -target_block_size_mb = 128 -ctx.target_max_block_size = target_block_size_mb * 1024 * 1024 -ctx.use_push_based_shuffle = False - - -# Data pipeline with scalability optimizations -dataset = ( - ray.data.read_parquet( - "hf://datasets/laion/relaion2B-en-research-safe/", - file_extensions=["parquet"], - columns=["url"], - filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]), - concurrency=concurrency, - num_cpus=2, - memory=int(4 * 1024**3), - ) - .map_batches( - fetch_images_batch_threaded, - batch_size=50, - memory=int(2 * 1024**3), - num_cpus=2, - ) # removed partition to reduce memory usage and redundency - .filter(lambda row: row["success"]) - .filter(is_jpeg_format) - .map(resize_image) - .drop_columns(["success"]) -) # Drop success column early to reduce memory - - -# Apply vision processing with scaled replicas -# Note: image_base64 column is dropped in vision_postprocess to avoid Arrow serialization issues -dataset = vision_processor(dataset) - -# Write with optimizations for throughput and fault tolerance -dataset.write_parquet( - output_path, - max_rows_per_file=100000, # ~100K rows per file for manageable file sizes -) diff --git a/image_processing/run.sh b/image_processing/run.sh deleted file mode 100755 index 8758376..0000000 --- a/image_processing/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -anyscale job submit -f job.yaml \ - --env HF_TOKEN=$HF_TOKEN \ No newline at end of file From dffc213d2b6a2983177618ec97c1f01691235a9c Mon Sep 17 00:00:00 2001 From: xyuzh Date: Tue, 18 Nov 2025 19:17:04 -0800 Subject: [PATCH 07/15] Add megatron_ray_fault_tolerant 
example with comprehensive fault tolerance - Implements PPO-style training with Megatron and Ray - Features automatic actor recovery from failures - Includes backup actor pool for seamless replacement - Supports DP, TP, PP, and CP parallelism - Distributed checkpoint saving/loading - Process group re-initialization after failures - Added comprehensive documentation in README files --- README.md | 102 ++ megatron_ray_fault_tolerant/.gitignore | 1 + .../.pre-commit-config.yaml | 20 + megatron_ray_fault_tolerant/Dockerfile | 34 + megatron_ray_fault_tolerant/README.md | 191 ++++ megatron_ray_fault_tolerant/dispatch.py | 299 ++++++ megatron_ray_fault_tolerant/file_io.py | 321 ++++++ megatron_ray_fault_tolerant/job.yaml | 45 + megatron_ray_fault_tolerant/main.py | 190 ++++ megatron_ray_fault_tolerant/megatron_actor.py | 934 ++++++++++++++++++ .../megatron_model_utils.py | 442 +++++++++ .../megatron_model_wrapper.py | 171 ++++ megatron_ray_fault_tolerant/megatron_utils.py | 465 +++++++++ megatron_ray_fault_tolerant/optimizer.py | 103 ++ megatron_ray_fault_tolerant/pyproject.toml | 98 ++ megatron_ray_fault_tolerant/run.sh | 1 + megatron_ray_fault_tolerant/training_batch.py | 371 +++++++ megatron_ray_fault_tolerant/utils.py | 286 ++++++ 18 files changed, 4074 insertions(+) create mode 100644 README.md create mode 100644 megatron_ray_fault_tolerant/.gitignore create mode 100644 megatron_ray_fault_tolerant/.pre-commit-config.yaml create mode 100644 megatron_ray_fault_tolerant/Dockerfile create mode 100644 megatron_ray_fault_tolerant/README.md create mode 100644 megatron_ray_fault_tolerant/dispatch.py create mode 100644 megatron_ray_fault_tolerant/file_io.py create mode 100644 megatron_ray_fault_tolerant/job.yaml create mode 100644 megatron_ray_fault_tolerant/main.py create mode 100644 megatron_ray_fault_tolerant/megatron_actor.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_utils.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_wrapper.py create mode 100644 megatron_ray_fault_tolerant/megatron_utils.py create mode 100644 megatron_ray_fault_tolerant/optimizer.py create mode 100644 megatron_ray_fault_tolerant/pyproject.toml create mode 100755 megatron_ray_fault_tolerant/run.sh create mode 100644 megatron_ray_fault_tolerant/training_batch.py create mode 100644 megatron_ray_fault_tolerant/utils.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..3a7e424 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# Examples + +This repository contains examples for deploying and running distributed applications. + +## Job Examples + +### 1. Hello World Job +**Directory:** `01_job_hello_world/` + +A simple "Hello World" example demonstrating how to submit and run basic jobs. + +### 2. Image Processing +**Directory:** `image_processing/` + +Process large-scale image datasets using Ray Data. This example demonstrates processing the ReLAION-2B dataset with over 2 billion rows. + +### 3. Megatron + Ray Fault Tolerant Training +**Directory:** `megatron_ray_fault_tolerant/` + +Implements PPO-style distributed training with Megatron and Ray, featuring comprehensive fault tolerance capabilities: +- Automatic actor recovery from failures +- Backup actor groups for seamless replacement +- Distributed checkpoint saving/loading +- Process group re-initialization after failures +- Support for tensor, pipeline, data, and context parallelism + +## Service Examples + +### 1. 
Hello World Service +**Directory:** `02_service_hello_world/` + +A simple service deployment example demonstrating the basics of Ray Serve. + +### 2. Deploy Llama 3.1 8B +**Directory:** `03_deploy_llama_3_8b/` + +Deploy Llama 3.1 8B model using Ray Serve and vLLM with autoscaling capabilities. + +### 3. Deploy Llama 3.1 70B +**Directory:** `deploy_llama_3_1_70b/` + +Deploy the larger Llama 3.1 70B model with optimized serving configuration. + +### 4. Tensor Parallel Serving +**Directory:** `serve_tensor_parallel/` + +Demonstrates tensor parallelism for serving large language models across multiple GPUs. + +### 5. FastVideo Generation +**Directory:** `video_generation_with_fastvideo/` + +Deploy a video generation service using the FastVideo framework. + +## Reinforcement Learning Examples + +### SkyRL +**Directory:** `skyrl/` + +Reinforcement learning training example using Ray and distributed computing. + +## Getting Started + +Most examples include their own README with specific instructions. Generally, you'll need: + +1. Install the Anyscale CLI: +```bash +pip install -U anyscale +anyscale login +``` + +2. Navigate to the example directory: +```bash +cd +``` + +3. Deploy the service or submit the job: +```bash +# For services +anyscale service deploy -f service.yaml + +# For jobs +anyscale job submit -f job.yaml +``` + +## Requirements + +- Anyscale account and CLI access +- Appropriate cloud credentials configured +- GPU resources for ML/LLM examples + +## Contributing + +When adding new examples: +1. Create a descriptive directory name +2. Include a README.md with setup and usage instructions +3. Add appropriate YAML configuration files +4. Update this main README with your example + +## License + +See individual example directories for specific licensing information. 
+ diff --git a/megatron_ray_fault_tolerant/.gitignore b/megatron_ray_fault_tolerant/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/megatron_ray_fault_tolerant/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/.pre-commit-config.yaml b/megatron_ray_fault_tolerant/.pre-commit-config.yaml new file mode 100644 index 0000000..5d51437 --- /dev/null +++ b/megatron_ray_fault_tolerant/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.9 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + exclude: (^(skyagent)/.*)$ + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + exclude: (^(skyagent)/.*)$ + + # Detect secrets and sensitive information + - repo: https://github.com/gitleaks/gitleaks + rev: v8.24.2 + hooks: + - id: gitleaks \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/Dockerfile b/megatron_ray_fault_tolerant/Dockerfile new file mode 100644 index 0000000..787c1c7 --- /dev/null +++ b/megatron_ray_fault_tolerant/Dockerfile @@ -0,0 +1,34 @@ +FROM anyscale/ray:2.51.0-slim-py312-cu128 + +RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev + +# the cuda compiler here is needed for deepspeed +RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \ + && sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run + +RUN curl -LsSf https://astral.sh/uv/0.9.4/install.sh | sh +RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc + + +RUN sudo apt-get update \ + && sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \ + libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \ + && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \ + && sudo apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +# ---------- PyTorch + cuDNN + Transformer Engine ---------- +# PyTorch + cuDNN + Transformer Engine +RUN pip install --no-cache-dir "torch==2.7.1" "nvidia-cudnn-cu12>=9.3" && \ + CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" && \ + sudo mkdir -p /opt && sudo ln -sfn "$CUDNN_PATH" /opt/cudnn && \ + echo "/opt/cudnn/lib" | sudo tee /etc/ld.so.conf.d/cudnn.conf >/dev/null && sudo ldconfig + +ENV CUDNN_PATH=/opt/cudnn +ENV CPATH=${CUDNN_PATH}/include:${CPATH} +ENV LD_LIBRARY_PATH=${CUDNN_PATH}/lib:${LD_LIBRARY_PATH} + +RUN pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.5.0" +# -------------------- diff --git a/megatron_ray_fault_tolerant/README.md b/megatron_ray_fault_tolerant/README.md new file mode 100644 index 0000000..b7abecc --- /dev/null +++ b/megatron_ray_fault_tolerant/README.md @@ -0,0 +1,191 @@ +# Megatron + Ray Fault Tolerant Training + +This example implements PPO-style distributed training using Megatron and Ray with comprehensive fault tolerance capabilities. The system can automatically recover from actor failures during training by utilizing backup actors and re-initializing process groups. + +## Key Features + +### Fault Tolerance Mechanisms + +1. 
**Actor Health Monitoring**: Continuously monitors the health of distributed training actors +2. **Backup Actor Pool**: Pre-allocated backup actors ready to replace failed workers +3. **Automatic Recovery**: Seamlessly recovers from failures by: + - Detecting dead actors + - Destroying old process groups + - Replacing failed actors with backup actors + - Re-initializing process groups with new world size + - Reloading model and optimizer state from checkpoints + +4. **Distributed Checkpointing**: Implements efficient sharded checkpoint saving/loading using Megatron's distributed checkpointing +5. **Process Group Management**: Handles NCCL process group initialization, destruction, and re-initialization + +### Parallelism Support + +- **Data Parallelism (DP)**: Distributes training data across multiple GPUs +- **Tensor Parallelism (TP)**: Splits model tensors across GPUs +- **Pipeline Parallelism (PP)**: Distributes model layers across GPUs +- **Context Parallelism (CP)**: Enables sequence parallelism for long contexts + +### Advanced Training Features + +- **PPO Training**: Implements Proximal Policy Optimization with micro-batch accumulation +- **Mixed Precision**: Supports BF16 training for improved performance +- **Gradient Accumulation**: Handles micro-batches with automatic gradient accumulation +- **Distributed Optimizer**: Uses Megatron's distributed optimizer for memory efficiency + +## Architecture + +### Core Components + +1. **MegatronActor** (`megatron_actor.py`): + - Individual training actor wrapping Megatron models + - Handles model initialization, forward/backward passes, and checkpointing + - Supports dynamic process group re-initialization + +2. **MegatronActorGroup** (`megatron_actor.py`): + - Manages a group of distributed actors + - Implements fault recovery logic + - Coordinates distributed training operations + +3. **Dispatch System** (`dispatch.py`): + - **MeshDispatch**: Distributes data across the device mesh (DP, SP, TP, PP) + - **PassThroughDispatch**: Broadcasts same data/commands to all actors + - Handles data sharding and result collection + +4. **Training Batch** (`training_batch.py`): + - Defines input/output batch structures for PPO training + - Supports chunking and concatenation for distributed operations + +5. **Checkpoint I/O** (`file_io.py`): + - Cloud-aware file I/O supporting S3, GCS, and local storage + - Efficient checkpoint upload/download with parallel transfers + +## Getting Started + +### Quick Start + +```bash +uv run --isolated main.py +``` + +This will: +1. Create a placement group with workers and backup GPUs +2. Initialize the actor group and model +3. Run a training step +4. Save a checkpoint +5. Simulate a failure by killing actors +6. Recover from the failure using backup actors +7. 
Resume training after recovery + +### Configuration + +Edit the `Config` class in `main.py` to customize: + +```python +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" # HuggingFace model name + num_nodes: int = 1 + num_gpus_per_node: int = 4 + num_spare_gpus: int = 4 # Backup actors for fault tolerance + mini_batch_size: int = 16 + micro_train_batch_size_per_gpu: int = 2 + + # Megatron parallelism settings + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) +``` + +### Megatron Parallelism Configuration + +```python +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 # TP degree + pipeline_model_parallel_size: int = 1 # PP degree + context_parallel_size: int = 1 # CP degree + expert_model_parallel_size: int = 1 # For MoE models +``` + +## Fault Recovery Workflow + +1. **Training Phase**: + - Actors perform distributed training using Megatron + - Periodic checkpoints saved to cloud storage + +2. **Failure Detection**: + - System detects actor failures via health checks + - Identifies affected data parallel groups + +3. **Recovery Process**: + - Destroy old process groups on healthy actors + - Pop backup actors from the backup pool + - Insert backup actors at failed ranks + - Update world size and reassign ranks + - Re-initialize process groups with new configuration + - Reload model/optimizer state from checkpoint + +4. **Resume Training**: + - Continue training with recovered actor group + - No loss of training progress (from last checkpoint) + +## Advanced Usage + +### Custom Dispatch Types + +Register custom dispatch strategies: + +```python +from dispatch import register_dispatch_type, Dispatch + +class CustomDispatch(Dispatch): + # Implement dispatch, collect, and validate methods + pass + +register_dispatch_type("custom", CustomDispatch) +``` + +### CPU Offloading (Experimental) + +For faster recovery, offload model/optimizer state to CPU memory: + +```python +# Before failure +ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + +# After recovery, on healthy actors +ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu")) +``` + +## Dependencies + +See `pyproject.toml` for full dependency list. 
Key dependencies: +- Ray for distributed orchestration +- Megatron-Core for model parallelism +- PyTorch with CUDA support +- Transformers for model loading +- vLLM and related libraries + +## Running on Anyscale + +Submit the job using: + +```bash +anyscale job submit -f job.yaml +``` + +The job configuration in `job.yaml` specifies: +- Container image with dependencies +- GPU instance types (g6e.12xlarge with 4xL4) +- Resource limits and scaling +- Environment variables for NCCL configuration + +## Limitations and Future Work + +- Virtual pipeline parallelism not yet supported +- CPU offloading optimization in progress +- Async checkpoint saving planned for future releases + +## References + +- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +- [Ray Documentation](https://docs.ray.io/) +- [Anyscale Platform](https://docs.anyscale.com/) diff --git a/megatron_ray_fault_tolerant/dispatch.py b/megatron_ray_fault_tolerant/dispatch.py new file mode 100644 index 0000000..9949c48 --- /dev/null +++ b/megatron_ray_fault_tolerant/dispatch.py @@ -0,0 +1,299 @@ +"""Defines dispatch and collect logic for distributed training""" + +from dataclasses import dataclass +from ray.actor import ActorHandle +from typing import List, Tuple, Optional, Dict, Type, Any +import asyncio +from abc import ABC, abstractmethod +import ray +from ray import ObjectRef +from training_batch import TrainingInputBatch, TrainingOutputBatch +import inspect + + +@dataclass +class MeshRank: + """Represents a rank in the device mesh. + + This is a tuple of (DP, SP, TP, PP) ranks. + """ + + dp: int + sp: int + tp: int + pp: int + + world_size: int + dp_size: int + pp_size: int + + def is_collection_dp_rank(self) -> bool: + """Check if this rank is a DP rank to collect from + + This is the rank with (SP=0, TP=0, PP=pp_size-1) + + Note: double check this for ETP > 1 (but this is not a typically used case) + """ + return self.tp == 0 and self.pp == self.pp_size - 1 and self.sp == 0 + + def __str__(self) -> str: + return f"MeshRank(dp={self.dp}, sp={self.sp}, tp={self.tp}, pp={self.pp}, world_size={self.world_size}, dp_size={self.dp_size}, pp_size={self.pp_size})" + + def __repr__(self) -> str: + return self.__str__() + + +@dataclass +class ActorInfo: + """Actor information for distributed training. + + This includes the actor handle and the rank in the device mesh. 
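+
+    For example, on a 4-GPU run with pure data parallelism, one actor might
+    carry (illustrative values):
+    `MeshRank(dp=2, sp=0, tp=0, pp=0, world_size=4, dp_size=4, pp_size=1)`.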
+ """ + + handle: ActorHandle + rank: MeshRank + + +class Dispatch(ABC): + """Base class for dispatch types + + Dispatch types are responsible for: + - dispatching method calls to actors handling data sharding if necessary + - collecting results from actors and concatenating results if necessary + - validating arguments for dispatch + """ + + @classmethod + @abstractmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + """Dispatches method calls to the actors with data sharing if necessary.""" + pass + + @classmethod + @abstractmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors asynchronously in an asyncio-compatible way.""" + pass + + @classmethod + @abstractmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors synchronously and returns a `TrainingOutputBatch`.""" + pass + + @classmethod + @abstractmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + """Validate and process arguments for dispatch. + + Returns: + Tuple of (args, kwargs) to be passed to dispatch + """ + pass + + +class MeshDispatch(Dispatch): + """Mesh dispatch type to dispatch data to a group of actors along the device mesh. + + Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. + The actor method should accept a single argument - the data batch. + + For data dispatch: + + * The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism. + * Each actor with the same DP rank processes the same data chunk in parallel. + + For data collection: + + * Data is collected only from the primary rank of each model/sequence parallel group. + * The primary rank is defined as the rank with (SP=0, TP=0, PP=0). + * The collected chunks are concatenated in order of DP rank to reconstruct the full data. + + Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1: + + * Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, + and all actors with DP rank 1 process the second chunk. + * Data collection: Only two actors contribute to the final output - the primary rank from each DP group: + (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order. 
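+
+    Illustrative call pattern (a sketch; `actor_infos` and `batch` are assumed
+    to be constructed elsewhere, e.g. by the actor group):
+
+        refs = MeshDispatch.dispatch(actor_infos, "ppo_train", batch)
+        output = MeshDispatch.sync_collect(actor_infos, refs)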
+ + """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch + ) -> List[ObjectRef]: + assert len(actor_infos) > 0, "actor_infos must be a non-empty list" + object_refs = [] + dp_size = actor_infos[0].rank.dp_size + assert ( + len(data) % dp_size == 0 + ), "data batch size must be divisible by dp_size, got {} and {}".format( + len(data), dp_size + ) + chunk_size = len(data) // dp_size + data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size) + + for actor_info in actor_infos: + # index into tensordict to get the correct data to send + data_to_send = data_chunks[actor_info.rank.dp] + object_refs.append(getattr(actor_info.handle, method).remote(data_to_send)) + return object_refs + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = ray.get(object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + # all should be none + assert all( + obj is None for obj in all_objects + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + sig = inspect.signature(cls.dispatch) + # pass dummy actor_infos and method_name + bound_args = sig.bind([], "dummy", *args, **kwargs) + bound_args.apply_defaults() + data = bound_args.arguments.get("data") + + # Check if there are any extra arguments + if len(bound_args.arguments) > 3: # data, actor_infos, method_name + # remove actor_infos and method_name - not added by user + bound_args.arguments.pop("actor_infos") + bound_args.arguments.pop("method") + raise ValueError( + f"MeshDispatch only accepts 'data' as an argument, got extra args: {bound_args.arguments}" + ) + + data = bound_args.arguments.get("data") + if not isinstance(data, TrainingInputBatch): + raise ValueError( + f"For MeshDispatch, `data` entry should be a `TrainingInput`, got {data}" + ) + args = (data,) + kwargs = {} + return args, kwargs + + +class PassThroughDispatch(Dispatch): + """PassThrough dispatch type to dispatch data to a group of actors without any sharding. + + This is useful for cases where we want to run the same method on all the actors. + Supports methods with any number of arguments. 
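+
+    Illustrative call pattern (a sketch; the checkpoint directory is a placeholder):
+
+        refs = PassThroughDispatch.dispatch(actor_infos, "save_checkpoint", ckpt_dir="...")
+        PassThroughDispatch.sync_collect(actor_infos, refs)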
+ """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + return [ + getattr(actor_info.handle, method).remote(*args, **kwargs) + for actor_info in actor_infos + ] + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + data_batches = ray.get(object_refs) + if len(data_batches) > 0 and data_batches[0] is not None: + assert isinstance( + data_batches[0], TrainingOutputBatch + ), "data_batches must be a list of `TrainingOutputBatch` objects" + return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches) + # all should be none + assert all( + obj is None for obj in data_batches + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + # no validation needed just pass everything + return args, kwargs + + +class DispatchRegistry: + _registry: Dict[str, Type[Dispatch]] = { + "mesh": MeshDispatch, + "pass_through": PassThroughDispatch, + } + + @classmethod + def register(cls, name: str, dispatch_class: Type[Dispatch]) -> None: + """Register a new dispatch type.""" + assert issubclass(dispatch_class, Dispatch) + cls._registry[name] = dispatch_class + + @classmethod + def get(cls, name: str) -> Type[Dispatch]: + """Get a registered dispatch type.""" + if name not in cls._registry: + raise KeyError(f"Dispatch type '{name}' not registered") + return cls._registry[name] + + @classmethod + def list_registered(cls) -> Dict[str, Type[Dispatch]]: + """List all registered dispatch types.""" + return cls._registry + + +def register_dispatch_type(name: str, dispatch_class: Type) -> None: + DispatchRegistry.register(name, dispatch_class) + + +def concatenate_outputs_after_mesh_dispatch( + actor_infos: List[ActorInfo], data_batches: List[TrainingOutputBatch] +) -> TrainingOutputBatch: + """Concatenate data batches from different ranks after mesh dispatch. + + - Data is collected only from the primary DP rank. + - The collected chunks are concatenated in order of DP rank to reconstruct the full data. + """ + assert len(actor_infos) == len( + data_batches + ), "`actor_infos` and `data_batches` must have the same length" + shards = [] + # collect in-order + dp_rank_to_shard = {} + for actor_info, data_batch in zip(actor_infos, data_batches): + if actor_info.rank.is_collection_dp_rank(): + dp_rank = actor_info.rank.dp + dp_rank_to_shard[dp_rank] = data_batch + for i in range(actor_infos[0].rank.dp_size): + shards.append(dp_rank_to_shard[i]) + return TrainingOutputBatch.cat(shards) diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py new file mode 100644 index 0000000..932adbe --- /dev/null +++ b/megatron_ray_fault_tolerant/file_io.py @@ -0,0 +1,321 @@ +""" +File I/O utilities for handling both local filesystem and cloud storage (S3/GCS). 
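+For example, `open_file("s3://my-bucket/ckpt/metadata.json", "rb")` (an
+illustrative path) returns a readable file object regardless of the backend.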
+ +This module provides a unified interface for file operations that works with: +- Local filesystem paths +- S3 paths (s3://bucket/path) +- Google Cloud Storage paths (gs://bucket/path or gcs://bucket/path) + +Uses fsspec for cloud storage abstraction. +""" + +import os +import tempfile +from contextlib import contextmanager +import fsspec +from loguru import logger +from datetime import datetime, timezone, timedelta + +# Optional AWS deps (present when s3fs is installed) +try: + import botocore.session as _botocore_session + from botocore.exceptions import ClientError + + _HAS_BOTOCORE = True +except Exception: + _HAS_BOTOCORE = False + + class ClientError(Exception): # fallback type + pass + + +_S3_FS = None # type: ignore + + +def get_s3_fs(): + """Return a cached S3 filesystem instance, creating it once.""" + global _S3_FS + if _S3_FS is None: + _S3_FS = fsspec.filesystem("s3") + return _S3_FS + + +def s3_expiry_time(): + """Return botocore credential expiry (datetime in UTC) or None.""" + if not _HAS_BOTOCORE: + return None + try: + sess = _botocore_session.get_session() + creds = sess.get_credentials() + if not creds: + return None + return getattr(creds, "expiry_time", None) or getattr( + creds, "_expiry_time", None + ) + except Exception: + return None + + +def s3_refresh_if_expiring(fs) -> None: + """ + Simple refresh: + - If expiry exists and is within 300s (or past), refresh with fs.connect(refresh=True). + - Otherwise, do nothing. + """ + exp = s3_expiry_time() + if not exp: + return + now = datetime.now(timezone.utc) + if now >= exp - timedelta(seconds=300): + try: + fs.connect(refresh=True) # rebuild session + except Exception: + pass + + +def call_with_s3_retry(fs, fn, *args, **kwargs): + """ + Wrapper for calling an S3 method. If it fails with ExpiredToken, force refresh once and retry. + """ + try: + return fn(*args, **kwargs) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fn(*args, **kwargs) + raise + + +def is_cloud_path(path: str) -> bool: + """Check if the given path is a cloud storage path.""" + return path.startswith(("s3://", "gs://", "gcs://")) + + +def _get_filesystem(path: str): + """Get the appropriate filesystem for the given path.""" + if not is_cloud_path(path): + return fsspec.filesystem("file") + + proto = path.split("://", 1)[0] + if proto == "s3": + fs = get_s3_fs() + s3_refresh_if_expiring(fs) + return fs + return fsspec.filesystem(proto) + + +def open_file(path: str, mode: str = "rb"): + """Open a file using fsspec, works with both local and cloud paths.""" + if not is_cloud_path(path): + return fsspec.open(path, mode) + + fs = _get_filesystem(path) + norm = fs._strip_protocol(path) + try: + return fs.open(norm, mode) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fs.open(norm, mode) + raise + + +def makedirs(path: str, exist_ok: bool = True) -> None: + """Create directories. 
Only applies to local filesystem paths.""" + if not is_cloud_path(path): + os.makedirs(path, exist_ok=exist_ok) + + +def exists(path: str) -> bool: + """Check if a file or directory exists.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.exists, path) + return fs.exists(path) + + +def isdir(path: str) -> bool: + """Check if path is a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.isdir, path) + return fs.isdir(path) + + +def list_dir(path: str) -> list[str]: + """List contents of a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.ls, path, detail=False) + return fs.ls(path, detail=False) + + +def remove(path: str) -> None: + """Remove a file or directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + if call_with_s3_retry(fs, fs.isdir, path): + call_with_s3_retry(fs, fs.rm, path, recursive=True) + else: + call_with_s3_retry(fs, fs.rm, path) + return + if fs.isdir(path): + fs.rm(path, recursive=True) + else: + fs.rm(path) + + +def upload_directory(local_path: str, cloud_path: str) -> None: + """Upload a local directory to cloud storage. + + Uploads the contents of local_path to cloud_path, not the directory itself. + This ensures consistent behavior across all ranks by explicitly uploading each file. + """ + if not is_cloud_path(cloud_path): + raise ValueError(f"Destination must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + + # Normalize paths: ensure cloud_path ends with / to indicate directory + cloud_path_normalized = cloud_path.rstrip("/") + "/" + + # Walk the local directory and upload each file explicitly + # This ensures we upload contents, not the directory as a subdirectory + for root, dirs, files in os.walk(local_path): + for file in files: + local_file_path = os.path.join(root, file) + # Get relative path from local_path to maintain directory structure + rel_path = os.path.relpath(local_file_path, local_path) + # Construct remote path: cloud_path/rel_path + remote_file_path = cloud_path_normalized + rel_path + + if cloud_path.startswith("s3://"): + # For S3, strip protocol for fsspec operations + remote_file_path_stripped = fs._strip_protocol(remote_file_path) + # Ensure parent directories exist in S3 (fsspec handles this automatically) + call_with_s3_retry( + fs, fs.put, local_file_path, remote_file_path_stripped + ) + else: + fs.put(local_file_path, remote_file_path) + + logger.info(f"Uploaded contents of {local_path} to {cloud_path}") + + +def download_directory(cloud_path: str, local_path: str) -> None: + """Download a cloud directory to local storage.""" + if not is_cloud_path(cloud_path): + raise ValueError(f"Source must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + cloud_path_normalized = cloud_path.rstrip("/") + "/" + os.makedirs(local_path, exist_ok=True) + + # List all files and download each one individually to download contents, not the folder + if cloud_path.startswith("s3://"): + remote_path_stripped = fs._strip_protocol(cloud_path_normalized) + all_files = call_with_s3_retry(fs, fs.find, remote_path_stripped, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(remote_path_stripped) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + 
parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + call_with_s3_retry(fs, fs.get, remote_file, local_file_path) + else: + all_files = fs.find(cloud_path_normalized, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(cloud_path_normalized) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + fs.get(remote_file, local_file_path) + + logger.info(f"Downloaded {cloud_path} to {local_path}") + + +@contextmanager +def local_work_dir(output_path: str): + """ + Context manager that provides a local working directory. + + For local paths, returns the path directly. + For cloud paths, creates a temporary directory and uploads content at the end. + + Args: + output_path: The final destination path (local or cloud) + + Yields: + str: Local directory path to work with + + Example: + with local_work_dir("s3://bucket/model") as work_dir: + # Save files to work_dir + model.save_pretrained(work_dir) + # Files are automatically uploaded to s3://bucket/model at context exit + """ + if is_cloud_path(output_path): + with tempfile.TemporaryDirectory() as temp_dir: + try: + yield temp_dir + finally: + # Upload everything from temp_dir to cloud path + upload_directory(temp_dir, output_path) + logger.info(f"Uploaded directory contents to {output_path}") + else: + # For local paths, ensure directory exists and use it directly + makedirs(output_path, exist_ok=True) + yield output_path + + +@contextmanager +def local_read_dir(input_path: str): + """ + Context manager that provides a local directory with content from input_path. + + For local paths, returns the path directly. + For cloud paths, downloads content to a temporary directory. + + Args: + input_path: The source path (local or cloud) + + Yields: + str: Local directory path containing the content + + Example: + with local_read_dir("s3://bucket/model") as read_dir: + # Load files from read_dir + model = AutoModel.from_pretrained(read_dir) + """ + if is_cloud_path(input_path): + with tempfile.TemporaryDirectory() as temp_dir: + # Download everything from cloud path to temp_dir + download_directory(input_path, temp_dir) + logger.info(f"Downloaded directory contents from {input_path}") + yield temp_dir + else: + # For local paths, use directly (but check it exists) + if not exists(input_path): + raise FileNotFoundError(f"Path does not exist: {input_path}") + yield input_path diff --git a/megatron_ray_fault_tolerant/job.yaml b/megatron_ray_fault_tolerant/job.yaml new file mode 100644 index 0000000..f1c2de2 --- /dev/null +++ b/megatron_ray_fault_tolerant/job.yaml @@ -0,0 +1,45 @@ +# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. + +name: megatron-fault-tolerance + +# When empty, use the default image. This can be an Anyscale-provided base image +# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided +# that it meets certain specs), or you can build new images using the Anyscale +# image builder at https://console.anyscale-staging.com/v2/container-images. +# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 +containerfile: ./Dockerfile + +# When empty, Anyscale will auto-select the instance types. You can also specify +# minimum and maximum resources. +compute_config: + # Pin worker nodes to g6.xlarge (1xL4) so the vision workload lands on L4 GPUs. 
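+  # g6e.12xlarge provides 4 GPUs per node, matching num_gpus_per_node=4 in main.py.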
+ worker_nodes: + - instance_type: g6e.12xlarge + min_nodes: 0 + max_nodes: 2 + min_resources: + CPU: 0 + GPU: 0 + max_resources: + CPU: 384 + GPU: 64 + +# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that +# will be the working directory for the job. The files in the directory will be +# automatically uploaded to the job environment in Anyscale. +working_dir: . + +# When empty, this uses the default Anyscale Cloud in your organization. +cloud: + +env_vars: + RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5" + NCCL_P2P_DISABLE: "1" + NCCL_SHM_DISABLE: "1" + +# The script to run in your job. You can also do "uv run main.py" if you have a +# pyproject.toml file in your working_dir. +entrypoint: uv run --isolated main.py + +# If there is an error, do not retry. +max_retries: 0 \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py new file mode 100644 index 0000000..b64b535 --- /dev/null +++ b/megatron_ray_fault_tolerant/main.py @@ -0,0 +1,190 @@ +import os +from dataclasses import dataclass, field +import ray +from typing import Optional, List +from megatron_actor import MegatronActorGroup +from ray.util.placement_group import placement_group + +import random +import time +from utils import get_test_training_batch, get_reordered_bundle_indices + + +@dataclass +class DDPConfig: + grad_reduce_in_fp32: bool = True + overlap_grad_reduce: bool = False + overlap_param_gather: bool = False + average_in_collective: bool = True + + +@dataclass +class OptimizerConfig: + lr: float = 1.0e-6 + weight_decay: float = 1e-2 + max_grad_norm: float = 1.0 + offload_after_step: bool = True + num_warmup_steps: int = 0 + scheduler: str = "constant_with_warmup" + + +@dataclass +class TransformerConfig: + recompute_granularity: Optional[str] = None + recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"]) + recompute_method: Optional[str] = None + recompute_num_layers: Optional[int] = None + + +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + context_parallel_size: int = 1 + expert_model_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + ddp_config: DDPConfig = field(default_factory=DDPConfig) + optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) + transformer_config: TransformerConfig = field(default_factory=TransformerConfig) + + +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" + # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it + num_nodes: int = 1 + num_gpus_per_node: int = 4 + mini_batch_size: int = 16 + num_spare_gpus: int = 4 + micro_train_batch_size_per_gpu: int = 2 + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) + ckpt_dir: str = ( + os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt3/" + ) + # algorithm config + eps_clip_low: float = 0.2 + eps_clip_high: float = 0.2 + clip_ratio_c: float = 3.0 + + +def main(): + config = Config() + # create placement group including spare gpus + pg = placement_group( + [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node + + [{"GPU": 1, "CPU": 1}] * config.num_spare_gpus, + strategy="PACK", + ) + ray.get(pg.ready(), timeout=1200) + # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 + reordered_bundle_indices = get_reordered_bundle_indices(pg) + + 
actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_nodes, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[:-config.num_spare_gpus], + ) + actor_group.initiate_worker_process_group() + ray.get(actor_group.async_init_model(config.model)) + + # potentially need some time for dependencies like transformer-engine-torch to build on worker nodes (this is something good to warm start...) + backup_actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_spare_gpus // config.num_gpus_per_node, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[-config.num_spare_gpus:], + ) + # just place but don't initiate the worker process group for the backup actor group + # call a function to make sure the actors are placed + ray.get(backup_actor_group.async_run_method_no_dispatch("get_gpu_id")) + + # train on one batch + batch = get_test_training_batch(config.model, batch_size=32) + print("Starting training step 1...") + start_time = time.time() + ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) + print(f"Training step 1 took {time.time() - start_time:.2f} seconds") + + # save checkpoint + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") + + # TODO: add a cpu offload (or cpu save memory) call here + # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + + # TODO: run another training batch here and save results but don't save checkpoint + + # randomly kill an actor to simulate fault tolerance scenario + # TODO: go deeper into the actor code and throw an exception on a given node and catch it here + print("Simulating failure and recovery...") + start_time = time.time() + + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) + # get the whole dp group associated with the failed actor + dp_group_actors = [] + for actor_info in actor_group.actor_infos: + if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: + dp_group_actors.append(actor_info) + print( + f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." 
+ ) + for actor_info in dp_group_actors: + ray.kill(actor_info.handle) + + # Destroy process groups on all actors (including dead ones, which will fail gracefully) + print("Destroying old process groups...") + try: + ray.get( + actor_group.async_run_ray_method( + "pass_through", "destroy_worker_process_group" + ) + ) + except Exception as e: + print(f"Some actors failed during destroy (expected): {e}") + + for i, actor_info in enumerate(actor_group.actor_infos): + is_alive = actor_group._check_actor_alive(actor_info.handle) + print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") + + # Recover from failure: remove dead actors and re-initialize process group + print("Recovering from actor failure...") + actor_group.recover_from_failure(backup_actor_group) + + # load checkpoint on all actors + # TODO: improve the logic here + # we want to only call load checkpoint on the actors that are fresh + # on previously healthy actors we want to restore weights and optimizer state from cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu"), actor_ids=[previously healthy actor ids]) + # only for new actors, we want to load the checkpoint + ray.get( + actor_group.async_run_ray_method( + "pass_through", "load_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Recovery took {time.time() - start_time:.2f} seconds") + + # TODO: check that results here are the same as before the failure when resuming from checkpoint + # Test that training still works after recovery + print("Testing training after recovery...") + batch_after_recovery = get_test_training_batch(config.model, batch_size=32) + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "ppo_train", batch_after_recovery + ) + ) + print(f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds") + print("Recovery successful! 
Training works with remaining actors.") + + +if __name__ == "__main__": + main() diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py new file mode 100644 index 0000000..c1789de --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -0,0 +1,934 @@ +import logging +import os +import random +import socket +from dataclasses import asdict +from tqdm import tqdm +from typing import Optional, Dict, Any, List +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +import ray +from ray import ObjectRef +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer +from loguru import logger + +# megatron +from megatron.bridge import AutoBridge +import megatron.core.parallel_state as mpu +from megatron.core import dist_checkpointing +from megatron.core.dist_checkpointing.strategies import base as ckpt_base +from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) + +# local imports +import file_io as io # local io module to support cloud storage for checkpointing +from training_batch import TrainingOutputBatch +from optimizer import ( + init_megatron_optim_config, + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, +) +from megatron_model_wrapper import MegatronModelWrapper +from megatron_utils import ( + offload_megatron_model_to_cpu, + offload_megatron_optimizer, + load_megatron_model_to_gpu, + load_megatron_optimizer, + offload_megatron_grads_to_cpu, + load_megatron_grads_to_gpu, +) +from utils import BatchIterator +from dispatch import DispatchRegistry, Dispatch, ActorInfo, MeshRank + + +@ray.remote(num_gpus=1) +class MegatronActor: + def __init__( + self, + world_size, + rank, + local_rank, + master_addr, + master_port, + megatron_config, + seed, + cfg, + ): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self._world_size = world_size + self._rank = rank + self._local_rank = local_rank + self._master_addr = master_addr if master_addr else self._get_current_node_ip() + self._master_port = master_port if master_port else self._get_free_port() + os.environ["MASTER_ADDR"] = self._master_addr + os.environ["MASTER_PORT"] = str(self._master_port) + os.environ["WORLD_SIZE"] = str(self._world_size) + os.environ["RANK"] = str(self._rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + os.environ["LOCAL_RANK"] = "0" + self.megatron_config = megatron_config + self.seed = seed + self.cfg = cfg + + def get_node_local_rank(self): + return self._local_rank + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if torch.cuda.device_count() > 0: + from megatron.core import tensor_parallel + + tensor_parallel.model_parallel_cuda_manual_seed(seed) + + def init_worker_process_group(self): + """Initialize worker process group and megatron model parallel.""" + # Destroy any existing 
process group first to ensure clean state + if torch.distributed.is_initialized(): + try: + torch.distributed.destroy_process_group() + except Exception: + pass # Ignore errors if already destroyed + + # Initialize process group using environment variables + torch.distributed.init_process_group(backend="nccl") + + local_rank = int(os.environ.get("LOCAL_RANK", "-1")) + if local_rank != -1: + torch.cuda.set_device(local_rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.megatron_config.tensor_model_parallel_size, + pipeline_model_parallel_size=self.megatron_config.pipeline_model_parallel_size, + expert_model_parallel_size=self.megatron_config.expert_model_parallel_size, + expert_tensor_parallel_size=self.megatron_config.expert_tensor_parallel_size, + use_sharp=False, + context_parallel_size=self.megatron_config.context_parallel_size, + nccl_communicator_config_path=None, + ) + self.set_seed(self.seed) + self.world_size = dist.get_world_size() + self.mesh_rank = MeshRank( + dp=mpu.get_data_parallel_rank(), + sp=mpu.get_context_parallel_rank(), + tp=mpu.get_tensor_model_parallel_rank(), + pp=mpu.get_pipeline_model_parallel_rank(), + world_size=self._world_size, + dp_size=mpu.get_data_parallel_world_size(), + pp_size=mpu.get_pipeline_model_parallel_world_size(), + ) + + def get_mesh_rank(self): + return self.mesh_rank + + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + def print(self, *msg): + """Print only on rank 0""" + if dist.get_rank() == 0: + logger.info(*msg) + + @staticmethod + def _get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + def get_ray_node_id(self): + return ray.get_runtime_context().get_node_id() + + @staticmethod + def get_rng_state(): + """Get current RNG state for reproducibility""" + rng_state = { + "cpu": torch.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + + # Only save CUDA RNG state if CUDA is available and being used + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + rng_state["cuda"] = torch.cuda.get_rng_state() + + return rng_state + + @staticmethod + def load_rng_state(rng_state): + """Load RNG state for reproducibility""" + torch.set_rng_state(rng_state["cpu"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) + + # Only restore CUDA RNG state if it was saved and CUDA is available + if ( + "cuda" in rng_state + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + torch.cuda.set_rng_state(rng_state["cuda"]) + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + def destroy_worker_process_group(self): + mpu.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # Clear stale env vars + for env_var in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]: + if env_var in os.environ: + del os.environ[env_var] + + def reinit_model_after_recovery(self): + """Re-initialize model and optimizer after process group recovery. + + This is needed because the model and optimizer were created with the old + process group and still have references to old NCCL communicators. + + We need to fully reinitialize the provider and model to ensure they use + the new process group. 
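+
+        The recovery path mirrors `init_model`: re-create the bridge/provider,
+        re-wrap the module with DDP, rebuild the optimizer and LR scheduler,
+        recreate the `MegatronModelWrapper`, and finally re-normalize the per-GPU
+        mini batch size for the new data-parallel world size.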
+ """ + if not hasattr(self, "_model_path") or self._model_path is None: + # Fall back to cfg.model if _model_path not set + if hasattr(self.cfg, "model"): + model_path = self.cfg.model + else: + logger.warning("No model path found, cannot re-initialize model") + return + else: + model_path = self._model_path + + num_training_steps = getattr(self, "_num_training_steps", 1e9) + + logger.info("Re-initializing model components after process group recovery...") + + # Re-initialize the bridge and provider with the new process group + # This ensures all NCCL communicators are created fresh + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # Recreate the DDP-wrapped module with the new process group + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + # Recreate optimizer with the new process group + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + # Recreate scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # Recreate model wrapper + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # Re-normalize mini batch size with new world size + self._normalize_mini_batch_size() + + logger.info("Model components re-initialized successfully") + + def update_world_size(self, new_world_size: int): + """Update the world_size stored in the actor.""" + self._world_size = new_world_size + os.environ["WORLD_SIZE"] = str(new_world_size) + + def update_rank(self, new_rank: int): + """Update the rank stored in the actor.""" + self._rank = new_rank + os.environ["RANK"] = str(new_rank) + + def update_master_addr_port(self, master_addr: str, master_port: int): + """Update the master address and port for process group initialization.""" + self._master_addr = master_addr + self._master_port = master_port + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + def _normalize_mini_batch_size(self): + """ + Normalize mini batch sizes to per-gpu mini batch sizes. + """ + if not hasattr(self, "mesh_rank") or self.mesh_rank is None: + raise RuntimeError( + "mesh_rank must be initialized before calling _normalize_mini_batch_size()" + ) + + dp_size = self.mesh_rank.dp_size + self.policy_mini_batch_size_per_gpu = self.cfg.mini_batch_size // dp_size + + def ppo_train(self, train_data) -> "TrainingOutputBatch": + """ + Overrides `PolicyWorkerBase.ppo_train` for megatron. + + Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the + worker MegatronModelWrapper.forward_backward_mini_batch method. 
+ """ + dataloader = BatchIterator( + train_data, + sample_batch_size=self.cfg.micro_train_batch_size_per_gpu, + drop_last=False, + ) + + micro_batches_per_mini_batch = ( + self.policy_mini_batch_size_per_gpu + // self.cfg.micro_train_batch_size_per_gpu + ) + + self.optimizer.zero_grad() + pbar = tqdm( + dataloader, + desc="ppo train", + disable=not dist.get_rank() == 0, + ) + + micro_buffer = [] + for local_step, experience in enumerate(pbar): + experience.to_device(torch.cuda.current_device()) + sequences = experience.sequences + attention_mask = experience.attention_mask + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + } + ) + + if len(micro_buffer) == micro_batches_per_mini_batch: + # run mini-batch forward-backward and then one optimizer step + self.model.train() + for chunk in self.actor_module: + # if use distributed optimizer, zero grad buffer will be handled by optimizer + chunk.zero_grad_buffer() + seq_len = micro_buffer[0]["sequences"].shape[1] + micro_bsz = micro_buffer[0]["sequences"].shape[0] + + self.model.forward_backward_mini_batch( + micro_batches=micro_buffer, + seq_len=seq_len, + micro_batch_size=micro_bsz, + ) + + _, grad_norm, _ = self.optimizer.step() + self.scheduler.step(1) + self.optimizer.zero_grad() + + torch.distributed.barrier() + + def save_checkpoint(self, ckpt_dir: str): + # Extract base model. + model: List[nn.Module] = self.model.actor_module + optimizer = self.optimizer + scheduler = self.scheduler + node_local_rank = self.get_node_local_rank() + assert ( + len(model) == 1 + ), "Megatron virtual pipeline parallel is not yet supported" + model = model[0] + if hasattr(model, "module"): + model = model.module + + # Create checkpoint directory if it doesn't exist. + if node_local_rank == 0: + io.makedirs(ckpt_dir, exist_ok=True) + + # All ranks wait for the checkpoint directory to be created before saving. + dist.barrier() + + # Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler. + sharded_state_dict = {} + model_sharded_state_dict = model.sharded_state_dict() + sharded_state_dict["model"] = model_sharded_state_dict + if optimizer: + sharded_state_dict["optimizer"] = optimizer.sharded_state_dict( + model_sharded_state_dict + ) + if scheduler: + sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + + # Save RNG state. + sharded_state_dict["rng"] = self.get_rng_state() + + # Save the checkpoint across ranks in parallel. 
+ save_strategy = get_default_save_sharded_strategy("torch_dist") + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + + with io.local_work_dir(ckpt_dir) as work_dir: + # synchronous checkpointing for now + async_save_request = dist_checkpointing.save( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=work_dir, + sharded_strategy=save_strategy, + async_sharded_save=False, + validate_access_integrity=True, + ) + assert ( + async_save_request is None + ), "Async save is not yet supported for Megatron" + + dist.barrier() + ckpt_base.async_calls.close() + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + self.print(f"Checkpoint successfully saved to {ckpt_dir}") + + def load_checkpoint( + self, + ckpt_dir: str, + load_module_strict: bool = True, + load_optimizer_states: bool = True, + load_lr_scheduler_states: bool = True, + ): + if not ckpt_dir or not io.exists(ckpt_dir): + raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}") + + # Extract base model. + model: List[nn.Module] = self.model.actor_module + optimizer = self.optimizer + scheduler = self.scheduler + assert ( + len(model) == 1 + ), "Megatron virtual pipeline parallel is not yet supported" + unwrapped_model = model[0] + if hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + + # Extract sharded state dicts. + sharded_state_dict = {} + model_sharded_state_dict = unwrapped_model.sharded_state_dict() + sharded_state_dict["model"] = model_sharded_state_dict + if optimizer and load_optimizer_states: + sharded_state_dict["optimizer"] = optimizer.sharded_state_dict( + model_sharded_state_dict + ) + if scheduler and load_lr_scheduler_states: + sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + + # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory + # this should be improved to download only the relevant shards for this actor to load + with io.local_read_dir(ckpt_dir) as read_dir: + # Load the checkpoint in parallel. + load_strategy = get_default_load_sharded_strategy(read_dir) + load_strategy = FullyParallelLoadStrategyWrapper( + load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=read_dir, + sharded_strategy=load_strategy, + ) + + # Load the model, optimizer, and scheduler state dicts. + assert ( + "model" in state_dict + ), f"Model state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + model[0].load_state_dict(state_dict["model"], strict=load_module_strict) + self.print("Loaded model state dict.") + + if optimizer and load_optimizer_states: + assert ( + "optimizer" in state_dict + ), f"Optimizer state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + optimizer.load_state_dict(state_dict["optimizer"]) + self.print("Loaded optimizer state dict.") + + if scheduler and load_lr_scheduler_states: + assert ( + "lr_scheduler" in state_dict + ), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + scheduler.load_state_dict(state_dict["lr_scheduler"]) + self.print("Loaded LR scheduler state dict.") + + # Load RNG state, if present. 
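+        # Restoring RNG state keeps stochastic ops (e.g. dropout) reproducible when
+        # resuming after a failure; checkpoints without an "rng" entry are simply
+        # loaded without it.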
+ if "rng" in state_dict: + self.load_rng_state(state_dict["rng"]) + + return ckpt_dir, {} + + def offload_to_cpu(self): + self.all_buffer_sizes = offload_megatron_grads_to_cpu(self.actor_module) + offload_megatron_model_to_cpu(self.actor_module) + offload_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def backload_to_gpu(self): + load_megatron_grads_to_gpu(self.actor_module) + load_megatron_model_to_gpu(self.actor_module) + load_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # model init and bridge from huggingface methods: + def init_configs( + self, + model_path, + megatron_config, + transformer_config, + bf16=True, + flash_attn=True, + ): + """ + Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer + """ + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend + transformer_config = asdict(transformer_config) + transformer_config["attention_backend"] = "flash" if flash_attn else "fused" + + bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) + provider = bridge.to_megatron_provider() + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = ( + megatron_config.pipeline_model_parallel_size + ) + provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 + provider.context_parallel_size = megatron_config.context_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = ( + megatron_config.expert_tensor_parallel_size + ) + provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 + provider.attention_backend = "flash" if flash_attn else "fused" + provider.variable_seq_lengths = True + provider.masked_softmax_fusion = True + provider.moe_token_dispatcher_type = "alltoall" + + for k, v in transformer_config.items(): + setattr(provider, k, v) + provider.finalize() + + self.provider = provider + self.bridge = bridge + self.tokenizer = tokenizer + + def make_megatron_module( + self, + wrap_with_ddp: bool = True, + ddp_config: Optional[Dict[str, Any]] = None, + bf16: bool = True, + ) -> List[nn.Module]: + """ + Creates a megatron GPTModel (optionally DDP wrapped) using the bridge. + """ + from megatron.core.distributed.distributed_data_parallel_config import ( + DistributedDataParallelConfig, + ) + + default_ddp_config = DistributedDataParallelConfig() + if wrap_with_ddp: + default_ddp_config.use_distributed_optimizer = True + if ddp_config is not None: + for k, v in ddp_config.items(): + setattr(default_ddp_config, k, v) + model = self.provider.provide_distributed_model( + ddp_config=default_ddp_config, wrap_with_ddp=wrap_with_ddp, bf16=bf16 + ) + return model + + def init_model(self, model_path, num_training_steps: int = 1e9): + """ + Initialize the model, optimizer, and scheduler for the policy worker. 
+ """ + # Store model path for potential recovery + self._model_path = model_path + self._num_training_steps = num_training_steps + + # initialize the bridge and provider objects + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # wrap with DDP for training + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + if self._local_rank == 0 and not os.path.exists( + model_path + ): # if not local path, try downloading model weights from huggingface + snapshot_download(model_path) # will be no-op if already downloaded + torch.distributed.barrier() + + # create optimizer + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + self._normalize_mini_batch_size() + + # create scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # create worker model + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # NOTE: Set Megatron dist checkpoint async backend to persistent to avoid `os.fork()`-ing + # short-lived background workers, which does not work well with Ray. + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + + +class MegatronActorGroup: + """ + A group of distributed megatron actors + Functions start with 'async' should return list of object refs + + Args: + cfg: config object for workers + num_nodes (int): Number of nodes for this actor group. + num_gpus_per_node (int): Number of gpus for this actor group. + pg (PlacementGroup, optional): Placement group to schedule actor on. + If none, create new placement group automatically. Defaults to None. + num_gpus_per_actor (float, optional): Number of gpus allocated for each actor. + If < 1.0, multiple models can share same gpu. Defaults to 1. + """ + + def __init__( + self, + cfg, + num_nodes, + num_gpus_per_node, + pg: PlacementGroup, + bundle_indices: List[int], + num_gpus_per_actor: float = 1.0, + resources: Optional[Dict[str, float]] = None, + num_resources_per_node: Optional[int] = None, + ) -> None: + self.cfg = cfg + self._num_nodes = num_nodes + self._num_gpus_per_node = num_gpus_per_node + + # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html + self._resources = resources + self._num_resources_per_node = num_resources_per_node + + self._initiate_actors(pg, num_gpus_per_actor, bundle_indices) + + def _initiate_actors( + self, + pg: Optional[PlacementGroup], + num_gpus_per_actor: float, + bundle_indices: List[int], + ): + """Initialize Ray actors in the worker group. + + Args: + pg: The placement group for the worker group + num_gpus_per_actor: The number of gpus to allocate per actor. 
+ """ + world_size = self._num_nodes * self._num_gpus_per_node + assert pg is not None, "placement group must be provided to MegatronActorGroup" + pg_data = placement_group_table(pg) + assert ( + len(pg_data["bundles"]) >= world_size + ), "the number of bundles in the shared placement group must be greater than or equal to the world size" + + # place master actor on the + master_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[0], + ), + ).remote( + world_size=world_size, + rank=0, + local_rank=0, + master_addr=None, + master_port=None, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + + self._actor_handlers = [master_actor] + # Create worker actors + if world_size > 1: + master_addr, master_port = ray.get( + master_actor.get_master_addr_port.remote() + ) + for rank in range(1, world_size): + local_rank = rank % self._num_gpus_per_node + + worker_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[rank], + ), + ).remote( + world_size=world_size, + rank=rank, + local_rank=local_rank, + master_addr=master_addr, + master_port=master_port, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + self._actor_handlers.append(worker_actor) + + def initiate_worker_process_group(self): + # Initialize process group + logger.info("Initializing process group for RayActorGroup") + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + logger.info("Initialized process group for RayActorGroup") + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) + + def async_init_model( + self, + *args, + **kwargs, + ) -> List[ObjectRef]: + """Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers. + + Returns: + A list of ray object refs. + """ + return [ + actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers + ] + + def async_run_ray_method( + self, dispatch_type: str, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors using specified dispatch type asynchronously. 
+ + Args: + dispatch_type: Type of dispatch to use ("mesh" or "pass_through") + method_name: Name of the method to call on actors + *args: Positional arguments to pass to the method + **kwargs: Keyword arguments to pass to the method + + Returns: + List of object references + """ + dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type) + # validate the dispatch args to be sent to `.dispatch` + args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs) + + # Dispatch the method call + object_refs = dispatch_class.dispatch( + self.actor_infos, method_name, *args, **kwargs + ) + return object_refs + + def async_run_method_no_dispatch( + self, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors without dispatching.""" + return [ + getattr(handle, method_name).remote(*args, **kwargs) + for handle in self._actor_handlers + ] + + def _check_actor_alive(self, actor_handle) -> bool: + """Check if an actor is still alive by attempting to call a simple method.""" + try: + # Try to get a simple attribute or call a simple method with timeout + ray.get(actor_handle.get_mesh_rank.remote(), timeout=10) + return True + except Exception: + return False + + def recover_from_failure( + self, backup_actor_group: Optional["MegatronActorGroup"] = None + ): + """Recover from actor failures by removing dead actors and re-initializing process group.""" + logger.info("Starting recovery from actor failure...") + + # Filter out dead actors - both actor_infos and actor_handlers should be in sync + alive_actor_handlers = [] + num_dead_actors = 0 + dead_actor_ranks = [] + + for i, (actor_info, actor_handle) in enumerate( + zip(self.actor_infos, self._actor_handlers) + ): + if self._check_actor_alive(actor_info.handle): + alive_actor_handlers.append(actor_handle) + else: + logger.warning(f"Actor {i} is dead, removing from group") + num_dead_actors += 1 + dead_actor_ranks.append(i) + + if len(alive_actor_handlers) == 0: + raise RuntimeError("All actors are dead, cannot recover") + + if len(alive_actor_handlers) == len(self._actor_handlers): + logger.info("All actors are alive, no recovery needed") + return + + logger.info( + f"Recovering with {len(alive_actor_handlers)}/{len(self._actor_handlers)} actors" + ) + + self._actor_handlers = alive_actor_handlers + + # Destroy existing process groups on alive actors first + logger.info("Destroying old process groups...") + try: + ray.get( + [ + actor.destroy_worker_process_group.remote() + for actor in self._actor_handlers + ] + ) + except Exception as e: + logger.warning( + f"Some errors during process group destruction (may be expected): {e}" + ) + + # if backup actor group is provided, we pop idle actors from the backup actor group and insert them into the current actor group + if backup_actor_group is not None: + logger.info( + f"Popping {num_dead_actors} idle actors from backup actor group" + ) + idle_actor_handles = [ + backup_actor_group._actor_handlers.pop() for _ in range(num_dead_actors) + ] + # let's assume for now that the dead actors are contiguous in the actor group, so we insert the idle actors at the rank of the first dead actor + rank_to_insert = min(dead_actor_ranks) + logger.info(f"Inserting idle actors at rank {rank_to_insert}") + self._actor_handlers = ( + self._actor_handlers[:rank_to_insert] + + idle_actor_handles + + self._actor_handlers[rank_to_insert:] + ) + + # Re-initialize process group with remaining actors + # Update world_size and ranks to match the number of alive actors + new_world_size = 
len(self._actor_handlers) + + # Update world_size and reassign ranks sequentially (0, 1, 2, ...) + logger.info(f"Updating world_size to {new_world_size} and reassigning ranks...") + update_tasks = [] + for new_rank, actor in enumerate(self._actor_handlers): + update_tasks.append(actor.update_world_size.remote(new_world_size)) + update_tasks.append(actor.update_rank.remote(new_rank)) + ray.get(update_tasks) + + # get master address and a new free port for the new process group + master_addr, _ = ray.get(self._actor_handlers[0].get_master_addr_port.remote()) + master_port = ray.get(self._actor_handlers[0]._get_free_port.remote()) + logger.info(f"Using master_addr={master_addr}, master_port={master_port}") + + # Update master address/port in all actors + ray.get( + [ + actor.update_master_addr_port.remote(master_addr, master_port) + for actor in self._actor_handlers + ] + ) + + # Re-initialize process groups with new world_size and ranks + logger.info( + f"Re-initializing process group with world_size={new_world_size}..." + ) + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + + # Re-initialize model and optimizer with the new process group + # This is critical because they were created with the old process group + logger.info("Re-initializing model and optimizer with new process group...") + ray.get( + [ + actor.reinit_model_after_recovery.remote() + for actor in self._actor_handlers + ] + ) + + # Update actor_infos with new mesh ranks + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Recovery complete. New mesh ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) diff --git a/megatron_ray_fault_tolerant/megatron_model_utils.py b/megatron_ray_fault_tolerant/megatron_model_utils.py new file mode 100644 index 0000000..bc5be4a --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_utils.py @@ -0,0 +1,442 @@ +# Utils ported from NeMo-Aligner by way of NeMo-RL +# https://github.com/NVIDIA-NeMo/RL/blob/9301d36cbf847212430b84a27cfe6990f773b7cf/nemo_rl/distributed/model_utils.py#L4 +# The original copyright is reproduced below: + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + + +@torch.no_grad() +def _compute_distributed_log_softmax( + vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup +) -> torch.Tensor: + """Compute a stable distributed log softmax across tensor parallel workers. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_length, vocab_size//TP] + where TP is the tensor parallel size. + group (torch.distributed.ProcessGroup): Process group for the all-reduce operations. 
+ + Returns: + torch.Tensor: Log softmax output with the same shape as input, but values represent + log probabilities normalized across the full vocabulary dimension. + """ + logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True) + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max + + sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float() + + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype) + + +class DistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group) + softmax_output = log_probs.exp() + + log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) + log_probs[target_mask] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(softmax_output, target_mask, masked_target) + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + softmax, target_mask, masked_target = ctx.saved_tensors + + if softmax.ndim == 3: + B, S, V = softmax.shape + + # skip `torch.nn.functional.one_hot` + row = ( + torch.arange(B, device=softmax.device) + .view(-1, 1) + .expand(-1, S) + .reshape(-1) + ) + col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) + flat_idx = (row * S + col) * V + + flat_chosen = flat_idx.masked_select( + ~target_mask.reshape(-1) + ) + masked_target.masked_select(~target_mask) + + # `neg` is zero-copy + grad_input = softmax.neg() + grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) + + grad_output_selected = grad_output.masked_select(~target_mask) + grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) + else: + V = softmax.size(-1) + is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot( + masked_target, num_classes=V + ) + grad_input = is_chosen.float().sub_(softmax) + grad_input.mul_(grad_output.unsqueeze(-1)) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +class 
ChunkedDistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + The log probabilities computation is chunked in the sequence dimension + to mitigate GPU OOM (especially during backward pass). + In addition, logits casting from float16 or bfloat16 -> float32 is performed + inside the chunk loop to avoid materializing a whole float32 logits tensor. + + Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_log_probs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + + log_probs = torch.gather( + log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) + ).squeeze(-1) + log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + all_log_probs.append(log_probs) + + log_probs = torch.cat(all_log_probs, dim=1) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(vocab_parallel_logits, target_mask, masked_target) + ctx.chunk_size = chunk_size + ctx.tp_group = tp_group + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + vocab_parallel_logits, target_mask, masked_target = ctx.saved_tensors + chunk_size = ctx.chunk_size + tp_group = ctx.tp_group + + partition_vocab_size = int(vocab_parallel_logits.shape[-1]) + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + + all_grad_input = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + softmax_output = softmax_output.exp() + + # 1 if it's the chosen log prob, 0 otherwise + is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( + -1 + ) * torch.nn.functional.one_hot( + masked_target[:, chunk_start:chunk_end], + num_classes=partition_vocab_size, + ) + + grad_input = is_chosen.float().sub_(softmax_output) + + grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) + + all_grad_input.append(grad_input) + + 
grad_input = torch.cat(all_grad_input, dim=1) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +def from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, +) -> torch.Tensor: + """Get log probabilities from TP+CP sharded vocab logits. + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] + where TP is the tensor parallel size. + target (torch.Tensor): Target token indices with shape [batch_size, seq_len]. + NOTE: Must be the unmodified targets as this function will shift them internally. + vocab_start_index (int): Starting vocabulary index for this worker's partition. + vocab_end_index (int): Ending vocabulary index for this worker's partition. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. + inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + + Returns: + torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. + The sequence dimension is reduced by 1 due to the target shifting. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + """ + target = target.roll(shifts=-1, dims=-1) + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + pad_len = 0 + # if cp_size > 1: + # Pad the targets to local size * cp_size + pad_len = vocab_parallel_logits.shape[1] * cp_size - target.shape[1] + if pad_len > 0: + target = torch.nn.functional.pad(target, (0, pad_len), value=0) + + # Shard the targets by context parallelism + cp_rank = torch.distributed.get_rank(cp_group) + target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) + + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() + + if cp_size > 1: + # we need to gather the logits by context parallelism + logprobs = allgather_cp_sharded_tensor( + logprobs, cp_group, seq_dim=1 + ) # , unpadded_seqlen=target.shape[1]) + + if pad_len > 0: + logprobs = logprobs[:, :-pad_len] + + return logprobs[:, :-1] + + +def _get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1, +) -> torch.Tensor: + """Get tokens on this context parallelism rank. + + Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. 
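+
+    For example, with cp_size=2 the sequence is split into 4 chunks and cp_rank 0
+    takes chunks (0, 3) while cp_rank 1 takes chunks (1, 2), balancing the
+    causal-attention work across ranks.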
+ + Args: + input_ids: Input token IDs [seq_length, ] + cp_rank: Context parallelism rank + cp_size: Context parallelism size + + Returns: + Tokens on this context parallelism rank [1, seq_length // cp_size] + """ + if cp_size == 1: + return input_ids + + # load balance for causal attention + shard_size = input_ids.shape[seq_dim] // (cp_size * 2) + shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1) + + # Create slices for each dimension + slices = [slice(None)] * input_ids.dim() + ids_chunks = [] + + for ind in shard_inds: + slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) + ids_chunks.append(input_ids[slices]) + + ids = torch.cat(ids_chunks, dim=seq_dim) + return ids + + +def allgather_cp_sharded_tensor( + tensor, cp_group, seq_dim=1 +): # , unpadded_seqlen=None): + return AllGatherCPTensor.apply(tensor, cp_group, seq_dim) # , unpadded_seqlen) + + +class AllGatherCPTensor(torch.autograd.Function): + def forward( + ctx, tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1 + ): # , unpadded_seqlen: Optional[int] = None): + cp_size = torch.distributed.get_world_size(cp_group) + cp_rank_chunks = [] + for _ in range(cp_size): + cp_rank_chunks.append(torch.empty_like(tensor)) + + torch.distributed.all_gather( + tensor_list=cp_rank_chunks, tensor=tensor, group=cp_group + ) + + # undo the CP load balancing chunking + tensor_chunks = [] + for logit_chunk in cp_rank_chunks: + tensor_chunks.extend(torch.chunk(logit_chunk, chunks=2, dim=seq_dim)) + + chunk_indices = [] + for cp_rank in range(cp_size): + chunk_indices.append(cp_rank) + chunk_indices.append(2 * cp_size - cp_rank - 1) + + chunks_and_indices = list(zip(tensor_chunks, chunk_indices)) + chunks_and_indices = sorted(chunks_and_indices, key=lambda tup: tup[1]) + ret_tensor = [chunk for chunk, _ in chunks_and_indices] + ret_tensor = torch.cat(ret_tensor, dim=seq_dim) + + ctx.seq_dim = seq_dim + ctx.cp_group = cp_group + # ctx.unpadded_seqlen = unpadded_seqlen + + return ret_tensor + + def backward(ctx, grad_output): + cp_size = torch.distributed.get_world_size(ctx.cp_group) + cp_rank = torch.distributed.get_rank(ctx.cp_group) + torch.distributed.all_reduce(grad_output, group=ctx.cp_group) + + # chunk the seqdim in 2*cp chunks, and select with a CP load balanced indexing + seq_dim = ctx.seq_dim + # if ctx.unpadded_seqlen is not None: + # # Zero out grad_output along the seq_dim after unpadded_seqlen + # slicer = [slice(None)] * grad_output.dim() + # slicer[seq_dim] = slice(ctx.unpadded_seqlen, None) + # grad_output[tuple(slicer)] = 0 + + grad_output = grad_output.view( + *grad_output.shape[0:seq_dim], + 2 * cp_size, + grad_output.shape[seq_dim] // (2 * cp_size), + *grad_output.shape[(seq_dim + 1) :], + ) + + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + + grad_input = grad_output.index_select(seq_dim, index) + grad_input = grad_input.view( + *grad_input.shape[0:seq_dim], -1, *grad_input.shape[(seq_dim + 2) :] + ) + + return grad_input, None, None # , None diff --git a/megatron_ray_fault_tolerant/megatron_model_wrapper.py b/megatron_ray_fault_tolerant/megatron_model_wrapper.py new file mode 100644 index 0000000..07e885d --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_wrapper.py @@ -0,0 +1,171 @@ +from typing import Optional, List +from functools import partial +import torch +import torch.nn as nn + +from megatron.core.pipeline_parallel import get_forward_backward_func +import megatron.core.parallel_state as mpu +from 
megatron.core.distributed import finalize_model_grads + +from megatron_model_utils import from_parallel_logits_to_logprobs +from megatron_utils import ( + get_model_config, + make_batch_generator, + preprocess_packed_seqs, + postprocess_packed_seqs, +) +from utils import ppo_policy_loss + + +class MegatronModelWrapper: + def __init__( + self, + config, + actor_module: List[nn.Module], + actor_optimizer: Optional[torch.optim.Optimizer] = None, + ): + self.cfg = config + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + + config = get_model_config(self.actor_module[0]) + # This is set to None by default: https://github.com/NVIDIA/Megatron-LM/blob/07b22a05136a3cb08ece05f7de38cf6aeeb165fb/megatron/core/model_parallel_config.py#L95 + # use the build in finalize_model_grads function to all reduce gradients across parallelism dimensions + config.finalize_model_grads_func = finalize_model_grads + + def train(self): + [module.train() for module in self.actor_module] + + def eval(self): + [module.eval() for module in self.actor_module] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward_backward_mini_batch( + self, + micro_batches: List[dict], + seq_len: int, + micro_batch_size: int, + temperature: float = 1.0, + ) -> List[dict]: + """ + Run forward-backward over a full mini-batch consisting of multiple micro-batches. + + Args: + micro_batches: A list of micro-batch dicts. Each dict must contain keys: + "sequences", "attention_mask", "position_ids", "num_actions", + "old_action_log_probs", "base_action_log_probs", "advantages", + "loss_mask". + seq_len: Sequence length (tokens) per sample (assumed same across micros after padding). + micro_batch_size: Micro-batch size per forward pass. + temperature: Optional temperature for logits scaling. + + Returns: + List[dict]: one metrics dict per micro-batch in order. 
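+
+        Note:
+            Loss metrics are computed on the last pipeline stage and broadcast to the
+            remaining pipeline-parallel ranks after the forward-backward pass.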
+ """ + forward_backward_func = get_forward_backward_func() + + def loss_func(logits, data): + sequences = data["sequences"] + num_actions = data["num_actions"] + old_action_log_probs = data["old_action_log_probs"] + advantages = data["advantages"] + loss_mask = data["loss_mask"] + + tp_grp = mpu.get_tensor_model_parallel_group() + tp_rank = mpu.get_tensor_model_parallel_rank() + + # temperature normalization + if temperature != 1.0: + logits.div_(temperature) + + token_logprobs = from_parallel_logits_to_logprobs( + logits, + sequences, + vocab_start_index=tp_rank * logits.shape[-1], + vocab_end_index=(tp_rank + 1) * logits.shape[-1], + tp_group=tp_grp, + inference_only=False, + cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` + chunk_size=None, + ) + + action_log_probs = token_logprobs[:, -num_actions:] + + # policy loss should be calculated based on the selected token logprobs + policy_loss, clip_ratio = ppo_policy_loss( + action_log_probs, + old_action_log_probs, + advantages, + config=self.cfg, + loss_mask=loss_mask, + ) + + # no kl loss or entropy loss + loss = policy_loss + + metrics = { + "policy_loss": policy_loss.detach().item(), + "ppo_clip_ratio": clip_ratio, + } + return loss, metrics + + def forward_step(batch_iter, model): + batch = next(batch_iter) + + sequences = batch["sequences"] + attention_mask = batch["attention_mask"].to(bool) + + new_sequences, packed_seq_params = preprocess_packed_seqs( + sequences, + attention_mask, + pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True), + ) + new_attention_mask = None + new_position_ids = None + + outputs = model( + new_sequences, + new_position_ids, + new_attention_mask, + packed_seq_params=packed_seq_params, + ) + + outputs = postprocess_packed_seqs( + outputs, + packed_seq_params, + attention_mask, + micro_batch_size, + seq_len, + post_process=mpu.is_pipeline_last_stage(ignore_virtual=True), + ) + + return outputs, partial(loss_func, data=batch) + + # batch should be a list of micro-batches + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.actor_module) + ) + + metrics_list = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=len(micro_batches), + seq_length=seq_len, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + # broadcast metrics to all pp ranks + if not mpu.is_pipeline_last_stage(ignore_virtual=True): + metrics_list = [None] * len(micro_batches) + with torch.no_grad(): + torch.distributed.broadcast_object_list( + metrics_list, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + + return metrics_list diff --git a/megatron_ray_fault_tolerant/megatron_utils.py b/megatron_ray_fault_tolerant/megatron_utils.py new file mode 100644 index 0000000..4a0f015 --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_utils.py @@ -0,0 +1,465 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron_utils.py#L4 +# https://github.com/volcengine/verl/blob/dfa3933ac44b545fca1f6a8519fd07394a2cde1c/verl/models/mcore/util.py +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import gc +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.transformer.module import Float16Module +from megatron.core.optimizer import ChainedOptimizer +from megatron.core import parallel_state as mpu +from megatron.core.utils import get_attr_wrapped_model +from megatron.core.packed_seq_params import PackedSeqParams + +ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def make_batch_generator(batches, vpp_size): + """ + Creates a batch generator suitable for Megatron pipeline parallelism, + handling virtual pipeline parallelism (VPP). + + If VPP is used (vpp_size > 1), it duplicates the batch iterator for each + virtual pipeline stage. Otherwise, it returns a single iterator. + + Args: + batches: An iterable (e.g., list) of micro-batches. + vpp_size (int): The virtual pipeline model parallel size. + + Returns: + An iterator or a list of iterators over the micro-batches. + """ + if vpp_size > 1: + # has vpp + batch_generator = [batches] * vpp_size # number of vpp chunks + batch_generator = [iter(b) for b in batch_generator] + else: + # no vpp + batch_generator = iter(batches) + return batch_generator + + +@torch.no_grad() +def offload_megatron_grads_to_cpu(models): + all_buffer_sizes = [] + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + buffer_sizes = [] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.grad_data.storage().size() > 0: + buffer_sizes.append(buffer.grad_data.storage().size()) + buffer.grad_data.storage().resize_(0) + all_buffer_sizes.append(buffer_sizes) + else: + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + return all_buffer_sizes + + +@torch.no_grad() +def load_megatron_grads_to_gpu(models, buffer_sizes): + for i, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for j, buffers in enumerate(model_chunk_all_buffers): + for buffer in buffers: + buffer.grad_data.storage().resize_(buffer_sizes[i][j]) + buffer.grad_data.zero_() + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + 
model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = ( + buffer.param_data.data.cpu().pin_memory() + ) + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert ( + buffer.param_data_size + == buffer.param_data.cpu_data.storage().size() + ) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_model_to_gpu(models): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_( + buffer.param_data.cpu_data, non_blocking=True + ) + else: + # we need this for ref module + device_id = torch.cuda.current_device() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to("cpu", non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. 
+ """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = torch.cuda.current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, "_move_new_state_to_right_device"): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to( + torch.cuda.current_device(), non_blocking=True + ) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
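+    For example, with cp_size=2 each padded sequence is split into 4 chunks
+    [c0, c1, c2, c3]; CP rank 0 processes [c0, c3] and CP rank 1 processes
+    [c1, c2].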
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros( + batch_size + 1, dtype=torch.int32, device=input_ids.device + ) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = ( + seqlens_in_batch.tolist() + ) # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = ( + seqlens_in_batch_padded.tolist() + ) # lengths after padding + cu_seqlens_padded_cpu: list[int] = ( + cu_seqlens_padded.tolist() + ) # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros( + shape, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[ + i, attention_mask[i] + ] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[ + start_idx + half_seqlen : start_idx + half_seqlen + remain_len + ] = d[remain_start:remain_end] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move 
the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + seq_lens_cpu: list[int] = ( + attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() + ) + + shape = [batch_size, seq_len] + list( + output.shape[2:] + ) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather( + output_list, output.detach(), group=mpu.get_context_parallel_group() + ) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s] + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[ + packed_start_idx + + half_seqlen : packed_start_idx + + s_len_padded_chunk + ], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[ + s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen + ] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) diff --git a/megatron_ray_fault_tolerant/optimizer.py b/megatron_ray_fault_tolerant/optimizer.py new file mode 100644 index 0000000..f243397 --- /dev/null +++ b/megatron_ray_fault_tolerant/optimizer.py @@ -0,0 +1,103 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron/optimizer.py#L4 +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
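+
+# Minimal usage sketch (illustrative config values; `model_chunks` is assumed to
+# be the list of Megatron model chunks built by the training actor):
+#
+#   optim_cfg = init_megatron_optim_config({"lr": 1.0e-6, "weight_decay": 0.01})
+#   optimizer = get_megatron_optimizer(model_chunks, optim_cfg)
+#   scheduler = get_megatron_optimizer_param_scheduler(optimizer, {"lr": 1.0e-6})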
+ +import torch +from megatron.core.optimizer import OptimizerConfig +from megatron.core.optimizer import ( + get_megatron_optimizer as get_megatron_optimizer_native, +) +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler + + +def init_megatron_optim_config(optim_config) -> OptimizerConfig: + optim_args = { + "optimizer": optim_config.get("optimizer", "adam"), + "lr": optim_config.get("lr", 1.0e-6), + "min_lr": optim_config.get("min_lr", 0.0), + "clip_grad": optim_config.get("max_grad_norm", 1.0), + "weight_decay": optim_config.get("weight_decay", 0.01), + "bf16": True, + "params_dtype": torch.bfloat16, + "use_distributed_optimizer": True, + } + + config = OptimizerConfig(**optim_args) + return config + + +def get_megatron_optimizer( + model, + config: OptimizerConfig, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, +): + # Base optimizer. + return get_megatron_optimizer_native( + config=config, + model_chunks=model, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult, + ) + + +def get_megatron_optimizer_param_scheduler( + optimizer, + config, + num_training_steps: int = 1e9, # default to a large number for constant lr/wd +): + """ + Get the optimizer parameter scheduler for Megatron. + """ + lr_warmup_steps = config.get("num_warmup_steps", 0) + if config.get("lr_decay_steps", None) is None: + lr_decay_steps = num_training_steps + if config.get("lr_warmup_steps_ratio", None) is not None and ( + config.get("lr_warmup_steps", None) is None + or config.get("lr_warmup_steps", 0) <= 0 + ): + lr_warmup_steps = int(config.get("lr_warmup_steps_ratio", 0.0) * lr_decay_steps) + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=config.get("lr_warmup_init", 0.0), + max_lr=config.get("lr", 1.0e-6), + min_lr=config.get("min_lr", 0.0), + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style="constant", + start_wd=config.get("weight_decay", 0.01), + end_wd=config.get("weight_decay", 0.01), + wd_incr_steps=num_training_steps, + wd_incr_style="constant", + use_checkpoint_opt_param_scheduler=False, + override_opt_param_scheduler=True, + wsd_decay_steps=None, + lr_wsd_decay_style="exponential", + ) + + return opt_param_scheduler + + +def get_megatron_last_lr(optimizer): + """ + Get the last learning rate from the optimizer parameter scheduler. 
+ """ + return optimizer.param_groups[0]["lr"] diff --git a/megatron_ray_fault_tolerant/pyproject.toml b/megatron_ray_fault_tolerant/pyproject.toml new file mode 100644 index 0000000..51be3c3 --- /dev/null +++ b/megatron_ray_fault_tolerant/pyproject.toml @@ -0,0 +1,98 @@ +[project] +name = "ray-ft" +version = "0.0.1" +description = "ray" +authors = [ + {name = "ray", email = "ray@gmail.com"} +] +license = {text = "MIT"} +readme = "README.md" +requires-python = "==3.12.*" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +dependencies = [ + "ninja", + "tensorboard", + "func_timeout", + "transformers>=4.51.0", + "torchdata", + "omegaconf", + "ray==2.51.0", + "peft", + "debugpy==1.8.0", + "hf_transfer", + "wandb", + "datasets==4.0.0", + "flash-attn", + "polars", + "loguru", + "jaxtyping", + "s3fs", + # Make sure to change the flash attention source (under tool.uv.sources) above to a compatible version (<= 2.7.4.post1) for TransformerEngine==2.5.0 + # https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl + # For single node: build transformer-engine separately first, and uncomment the transformer-engine library import below + # uv pip install "torch==2.7.1" + # uv pip install "nvidia-cudnn-cu12>=9.3" + # export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" + # export CPATH="$CUDNN_PATH/include:${CPATH:-}" + # export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}" + # uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose + # "transformer-engine[pytorch]==2.5.0", + "transformer-engine[pytorch]==2.7.0", + "flash-attn==2.7.4.post1", + "vllm==0.10.1.1", + "torch==2.7.1", + "flashinfer-python", + "torchvision", + "megatron-bridge==0.1.0rc4", + "megatron-core==0.14.0", +] + +[tool.uv] +required-version = ">=0.8.10" +no-build-isolation-package = [ + "transformer-engine-torch", + "transformer-engine", +] + +[tool.uv.extra-build-dependencies] +flash-attn = [{requirement = "torch", match-runtime = true}] +transformer-engine = [{ requirement = "torch", match-runtime = true }, "build_tools"] +transformer-engine-torch = [{ requirement = "torch", match-runtime = true }, "build_tools"] + +[tool.uv.extra-build-variables] +flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"} + +[tool.uv.sources] +torch = { index = "pytorch-cu128" } +torchvision = { index = "pytorch-cu128" } +# We use `flashinfer-jit-cache` to avoid slow JIT compilation on first run. 
+# Different inference engines may pin different compatible flashinfer versions, so we provide the option to pin different versions for vllm/sglang +flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" } +flashinfer-python = [ + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" }, + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" } +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[[tool.uv.index]] +name = "flashinfer-cu128" +url = "https://flashinfer.ai/whl/cu128" +explicit = true + +[tool.setuptools] +include-package-data = true + +[tool.pytest.ini_options] +addopts = "-v -s" +testpaths = [ + "tests", +] \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/run.sh b/megatron_ray_fault_tolerant/run.sh new file mode 100755 index 0000000..c9455a3 --- /dev/null +++ b/megatron_ray_fault_tolerant/run.sh @@ -0,0 +1 @@ +anyscale job submit -f job.yaml \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/training_batch.py b/megatron_ray_fault_tolerant/training_batch.py new file mode 100644 index 0000000..eacdbe6 --- /dev/null +++ b/megatron_ray_fault_tolerant/training_batch.py @@ -0,0 +1,371 @@ +"""Defines interfaces for training data.""" + +from typing import TypedDict, Dict, Any, List, Optional, Generic, TypeVar +import torch +from jaxtyping import Float, Integer +import pickle +import io + +DictType = TypeVar("DictType") + + +# Class inspired by `TensorDict` but is much simpler. +class TensorBatch(dict, Generic[DictType]): + """Base class for training batches + + This defines a generic container for a batch of training data (inputs or outputs). + Consists of a dictionary of tensors along with some metadata. + """ + + metadata: Optional[Dict[str, Any]] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._batch_size = None + self._device = None + self._check_consistency() + + def select( + self, keys: List[str], metadata_keys: Optional[List[str]] = None + ) -> "TensorBatch[DictType]": + """Select a subset of the data batch. 
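+
+        For example, `batch.select(["sequences", "attention_mask"])` returns a
+        new batch holding only those two tensors; when `metadata_keys` is None,
+        the full metadata dict is carried over.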
+ + Args: + keys: The keys to select + metadata_keys: The metadata keys to select + + Returns: + A new `TensorBatch` object with the selected keys and metadata + """ + selected_batch_data = {} + for key in keys: + selected_batch_data[key] = self[key] + selected_metadata = {} + if metadata_keys is None: + selected_metadata = self.metadata + else: + selected_metadata = {} + for key in metadata_keys: + selected_metadata[key] = self.metadata[key] + new_batch = self.__class__(selected_batch_data) + new_batch.metadata = selected_metadata + return new_batch + + def _check_consistency(self): + """Check consistency of all present fields""" + keys = list(self.keys()) + if len(keys) == 0: + return + + batch_size = len(self[keys[0]]) + self._batch_size = batch_size + for key in keys: + value = self[key] + if value is None: + continue + self._device = value.device if self._device is None else self._device + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + if len(value) != batch_size: + raise ValueError(f"Batch size mismatch in {key}") + if value.device != self._device: + raise ValueError( + f"Device mismatch in {key}. Expected {self._device}, got {value.device}" + ) + + def __getitem__(self, index) -> "TensorBatch[DictType]": + if isinstance(index, slice): + return self.slice(index.start, index.stop, index.step) + elif isinstance(index, int): + return self.slice(index, index + 1) + else: + return super().__getitem__(index) + + def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None: + if value is None: + super().__setitem__(key, value) + return + + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + + if ( + hasattr(self, "_batch_size") + and self._batch_size is not None + and len(value) != self._batch_size + ): + raise ValueError( + f"Batch size mismatch in {key}. Expected tensor to be of size {self._batch_size}, got {len(value)}." + ) + + super().__setitem__(key, value) + + if hasattr(self, "_batch_size") and self._batch_size is None: + self._batch_size = len(value) + + def to( + self, + device: torch.device = None, + dtype: torch.dtype = None, + *, + non_blocking: bool = False, + ) -> "TensorBatch": + """Move tensors to device and/or cast to dtype. 
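+
+        Unlike `torch.Tensor.to`, this modifies the batch in place and returns
+        `self` rather than a copy.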
+ + Args: + device: The device to move the tensors to + dtype: The dtype to cast the tensors to + non_blocking: Whether the operation should be non-blocking + """ + for key, value in self.items(): + if value is None: + continue + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.to(device, dtype, non_blocking=non_blocking) + return self + + def contiguous(self) -> "TensorBatch": + """Make the tensors contiguous""" + for key, value in self.items(): + if value is None: + continue + # some of these asserts are not needed, but it's kept for type safety + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.contiguous() + return self + + @property + def batch_size(self) -> int: + """Batch size for the tensors""" + return self._batch_size + + @property + def device(self) -> torch.device: + """Get the device for the tensors""" + return self._device + + def __getstate__(self): + """Serialize the `TensorBatch` object for pickle protocol""" + self.contiguous() + if self._device is not None: + assert self._device == torch.device( + "cpu" + ), "Tensors must be on CPU before serialization" + batch_dict = {} + for key, value in self.items(): + buffer = io.BytesIO() + torch.save(value, buffer) + batch_dict[key] = buffer.getvalue() + + return { + "batch_dict": batch_dict, + "batch_size": self._batch_size, + "device": self._device, + "metadata": self.metadata, + } + + def __setstate__(self, state): + """Deserialize the `TensorBatch` object and load it into memory""" + for key, value in state["batch_dict"].items(): + buffer = io.BytesIO(value) + self[key] = torch.load(buffer) + + self._batch_size = state["batch_size"] + self._device = state["device"] + self.metadata = state["metadata"] + self._check_consistency() + return self + + def repeat(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated. + + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def repeat_interleave(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated. 
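+
+        For example, with `repeats=2` batch entries `[a, b]` become
+        `[a, a, b, b]`, whereas `repeat` tiles the batch to `[a, b, a, b]`.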
+ + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat_interleave(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]: + """Split into smaller chunks""" + chunks = [] + for i in range(0, self.batch_size, chunk_size): + chunk_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + chunk_data[key] = value[i : i + chunk_size] + else: + raise ValueError( + f"Unsupported type {type(value)} for key {key}" + ) + else: + # `None` values are not chunked + chunk_data[key] = value + chunk = self.__class__(chunk_data) + chunk.metadata = self.metadata + chunks.append(chunk) + return chunks + + def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]": + """Slice the data batch. + + Args: + start: The start index + end: The end index + step: The step size + + Returns: + A new `TensorBatch` object with the view of the specified slice. + """ + slice_obj = slice(start, end, step) + sliced_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + sliced_data[key] = value[slice_obj] + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not sliced + sliced_data[key] = value + sliced_batch = self.__class__(sliced_data) + sliced_batch.metadata = self.metadata + return sliced_batch + + def save(self, path: str): + """Save the data to a pickle file""" + with open(path, "wb") as f: + pickle.dump(self, f) + + def load(self, path: str): + """Load the data from a pickle file""" + with open(path, "rb") as f: + return pickle.load(f) + + @classmethod + def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]": + """Concatenate shards. + + Args: + shards: The list of `TensorBatch` objects to cat + + Returns: + A new `TensorBatch` object with the concatenated data + """ + cat_data = {} + assert len(shards) > 0, "Cannot cat an empty list of shards" + for key, value in shards[0].items(): + if value is not None: + if isinstance(value, torch.Tensor): + cat_data[key] = torch.cat([shard[key] for shard in shards]) + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not cat'd + cat_data[key] = value + metadata = shards[0].metadata + cat_batch = cls(cat_data) + cat_batch.metadata = metadata + return cat_batch + + def __len__(self) -> int: + """Length of the batch. + + Note that this is the same as the batch size rather than the number of keys in the batch. 
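+
+        For example, a batch holding a single tensor of shape (8, 512) has
+        length 8 (the batch size), not 1 (the number of keys).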
+ """ + return self._batch_size + + def __eq__(self, other: Any) -> bool: + """Check if two `TensorBatch` objects are equal""" + if not isinstance(other, TensorBatch): + return False + if self.metadata != other.metadata: + return False + if len(self) != len(other): + return False + if len(self.items()) != len(other.items()): + return False + for k, v in self.items(): + if k not in other or not torch.equal(v, other[k]): + return False + return True + + def __str__(self) -> str: + """String representation of the `TensorBatch` object""" + return f"TensorBatch(batch_size={self.batch_size}, device={self.device}, metadata={self.metadata}), items={self.items()}" + + def __repr__(self) -> str: + """String representation of the `TensorBatch` object""" + return self.__str__() + + +class TrainingInput(TypedDict, total=False): + """Schema for training input batch""" + + sequences: Integer[torch.Tensor, "batch_size seq_len"] + attention_mask: Integer[torch.Tensor, "batch_size seq_len"] + loss_mask: Integer[torch.Tensor, "batch_size seq_len"] + response_mask: Integer[torch.Tensor, "batch_size seq_len"] + action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + values: Optional[Float[torch.Tensor, "batch_size seq_len"]] + returns: Float[torch.Tensor, "batch_size seq_len"] + advantages: Float[torch.Tensor, "batch_size seq_len"] + kl: Float[torch.Tensor, "batch_size seq_len"] + rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] + rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + + +class TrainingInputBatch(TensorBatch[TrainingInput]): + """Training input data""" + + pass + + +class TrainingOutputBatch(TensorBatch[Dict[str, torch.Tensor]]): + """Training output data""" + + pass diff --git a/megatron_ray_fault_tolerant/utils.py b/megatron_ray_fault_tolerant/utils.py new file mode 100644 index 0000000..e07689d --- /dev/null +++ b/megatron_ray_fault_tolerant/utils.py @@ -0,0 +1,286 @@ +import ray +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +import torch +from typing import Any, Optional, Dict, List, Union, Tuple +from dataclasses import dataclass +from jaxtyping import Integer, Float +import math +from transformers import AutoTokenizer + + +from training_batch import TrainingInputBatch + +BasicType = Union[int, float, str, bool] + + +@ray.remote(num_gpus=1) +class InfoActor: + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + +def get_reordered_bundle_indices(pg: PlacementGroup): + """ + Get the reordered bundle indices for a placement group to ensure adjacent ranks are on the same node when possible + """ + pg_data = placement_group_table(pg) + num_bundles = len(pg_data["bundles"]) + bundle_to_node_ids = pg_data["bundles_to_node_id"] + + # use info actor to get the GPU id + info_actors = [] + for i in range(num_bundles): + info_actors.append( + InfoActor.options( + num_cpus=0.01, # set both num_cpus and num_gpus to be small values to enable assignment in colocated case + num_gpus=0.01, + resources=None, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=i, + ), + ).remote() + ) + + gpu_ids = ray.get([actor.get_gpu_id.remote() for actor in info_actors]) + for actor in info_actors: + ray.kill(actor) + + # original index, node_id, gpu_id + bundle_infos = [(i, bundle_to_node_ids[i], gpu_ids[i]) for i in range(num_bundles)] + pg_reordered_bundle_indices = [ + 
bundle_info[0] + for bundle_info in sorted(bundle_infos, key=lambda x: (x[1], x[2])) + ] # sort by node_id, then gpu_id + return pg_reordered_bundle_indices + + +def to(tensor: Union[torch.Tensor, List[torch.Tensor], BasicType], device): + if isinstance(tensor, list): + return [to(t, device) for t in tensor] + elif isinstance(tensor, torch.Tensor): + return tensor.to(device) + else: + return tensor + + +@dataclass +class Experience: + """Experience is a batch of data. + These data should have the the sequence length and number of actions. + Left padding for sequences is applied. + + Shapes of each tensor: + sequences: (B, S) + action_log_probs: (B, A) + base_action_log_probs: (B, A) + values: (B, A) + returns: (B, A) + advatanges: (B, A) + attention_mask: (B, S) + action_mask: (B, A) + kl: (B, A) + + "A" is the number of actions/ response length. + """ + + sequences: Integer[torch.Tensor, "batch seq_len"] + action_log_probs: Float[torch.Tensor, "batch response_len"] + base_action_log_probs: Optional[Float[torch.Tensor, "batch response_len"]] + values: Optional[Float[torch.Tensor, "batch response_len"]] + returns: Optional[Float[torch.Tensor, "batch response_len"]] + advantages: Optional[Float[torch.Tensor, "batch response_len"]] + attention_mask: Optional[Integer[torch.LongTensor, "batch seq_len"]] + loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] + action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] + rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + num_actions: int + info: Optional[dict] + kl: Optional[Float[torch.Tensor, "batch response_len"]] = None + metadata: Optional[Dict[str, Any]] = None + + @torch.no_grad() + def to_device(self, device: torch.device) -> None: + self.sequences = to(self.sequences, device) + self.action_log_probs = to(self.action_log_probs, device) + if self.base_action_log_probs is not None: + self.base_action_log_probs = to(self.base_action_log_probs, device) + if self.values is not None: + self.values = to(self.values, device) + if self.returns is not None: + self.returns = to(self.returns, device) + if self.advantages is not None: + self.advantages = to(self.advantages, device) + if self.attention_mask is not None: + self.attention_mask = to(self.attention_mask, device) + if self.loss_mask is not None: + self.loss_mask = to(self.loss_mask, device) + if self.action_mask is not None: + self.action_mask = to(self.action_mask, device) + if self.rollout_logprobs is not None: + self.rollout_logprobs = to(self.rollout_logprobs, device) + + +class BatchIterator: + """A simple iterator to yield micro batches of data from the training batch.""" + + def __init__( + self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False + ): + self.data = data + self.sample_batch_size = sample_batch_size + self.total_batch_size = data.batch_size + self.drop_last = drop_last + assert not drop_last, "drop_last is not supported yet" + num_micro_batches = self.total_batch_size / self.sample_batch_size + self.num_micro_batches = ( + int(num_micro_batches) if drop_last else math.ceil(num_micro_batches) + ) + # TODO: switch to tensordict.map_iter if possible + self._chunks = self.data.chunk(self.sample_batch_size) + self._iter = iter(self._chunks) + + def __len__(self): + return self.num_micro_batches + + def __iter__(self): + return self + + def __next__(self) -> Experience: + try: + batch = next(self._iter) + exp = self.batch_to_experience(batch) + return exp + except StopIteration: + self._iter = 
iter(self._chunks) + raise StopIteration + + @staticmethod + def batch_to_experience(batch: TrainingInputBatch): + exp = Experience( + sequences=batch["sequences"], + action_log_probs=batch["action_log_probs"], + base_action_log_probs=batch["base_action_log_probs"], + values=batch["values"], + returns=batch["returns"], + advantages=batch["advantages"], + attention_mask=batch["attention_mask"], + loss_mask=batch["loss_mask"], + action_mask=batch["response_mask"], + num_actions=batch.metadata["response_length"], # int + rollout_logprobs=( + batch["rollout_logprobs"] if "rollout_logprobs" in batch else None + ), + # additional info + # can be used to log metrics etc for micro-batches in the worker + info={}, + # propagate metadata as is + metadata=batch.metadata, + ) + return exp + + +def masked_mean( + tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None +) -> torch.Tensor: + if mask is None: + return tensor.mean(axis=dim) + return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim).clamp(min=1.0) + + +def _safe_exp_delta( + delta: torch.Tensor, clip: float = 20.0, out_dtype=None +) -> torch.Tensor: + """ + Clamp the delta before exponentiating to avoid potential overflow. + """ + y = torch.clamp(delta.to(torch.float32), -clip, clip).exp() + return y.to(out_dtype or delta.dtype) + + +def ppo_policy_loss( + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + config, + loss_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, float]: + """Compute dual clip PPO policy loss.""" + ratio = _safe_exp_delta( + log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype + ) + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages + loss = -torch.min(surr1, surr2) + clip_ratio = ( + masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() + ) + clip_pg_losses1 = loss + pg_losses3 = -advantages * config.clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + + loss = loss = masked_mean(loss, loss_mask) + return loss, clip_ratio + + +def get_test_training_batch(model_name, batch_size=4) -> TrainingInputBatch: + """ + Returns a test training batch with padded seqs and attention masks + + Gives a batch of 4 sequences with variable amounts of left padding, and variable response lengths/amounts of right padding + Attention masks are 1 for non-padding tokens, 0 for padding tokens + The rest of the fields are filled with dummy data + """ + assert batch_size % 4 == 0, "batch size must be divisible by 4" + num_repeats = batch_size // 4 + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + sentences = [ + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.", + "<|im_start|>user\nThe selling price of a bicycle that had sold $220 last year was increased by 15", + "What is the new price? 
Let's think step by step and output the final answer after `####`.<|im_end|>\n", + "<|im_start|>assistant\nTo find the new price of the bicycle after the increase,", + ] * num_repeats + + sequences = [tokenizer.encode(sentence) for sentence in sentences] + attention_masks = [[1] * len(seq) for seq in sequences] + num_actions = 10 + # max seq len 1 longer than the longest sequence so we always have some padding + max_seq_length = max([len(seq) for seq in sequences]) + 7 + + pad_token_id = tokenizer.pad_token_id + pad_before = [4, 0, 1, 6] * num_repeats + pad_after = [ + max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences) + ] + + for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): + sequences[i] = ( + [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after + ) + attention_masks[i] = [0] * pad_before + attention_masks[i] + [0] * pad_after + + attention_masks = torch.tensor(attention_masks) + sequences = torch.tensor(sequences) + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_masks, + "action_log_probs": torch.tensor([[0.1] * num_actions] * batch_size), + "base_action_log_probs": torch.tensor([[0.2] * num_actions] * batch_size), + "rollout_logprobs": torch.tensor([[0.11] * num_actions] * batch_size), + "values": torch.tensor([[0.1] * num_actions] * batch_size), + "returns": torch.tensor([[0.1] * num_actions] * batch_size), + "advantages": torch.tensor([[0.5] * num_actions] * batch_size), + "loss_mask": torch.tensor([[1] * num_actions] * batch_size), + "response_mask": torch.tensor([[1] * num_actions] * batch_size), + } + ) + data.metadata = {"response_length": num_actions} + return data From e16da1e7079599ea5b9d4f872d18b02e50c86c27 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 24 Nov 2025 19:47:00 +0000 Subject: [PATCH 08/15] change dp-pp init ordering to put things on same node --- megatron_ray_fault_tolerant/main.py | 17 +++++++++++++---- megatron_ray_fault_tolerant/megatron_actor.py | 11 ++++++----- megatron_ray_fault_tolerant/megatron_utils.py | 1 + megatron_ray_fault_tolerant/pyproject.toml | 10 ---------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index b64b535..15eb7da 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -38,11 +38,11 @@ class TransformerConfig: @dataclass class MegatronConfig: - tensor_model_parallel_size: int = 1 - pipeline_model_parallel_size: int = 1 + tensor_model_parallel_size: int = 2 + pipeline_model_parallel_size: int = 2 context_parallel_size: int = 1 expert_model_parallel_size: int = 1 - expert_tensor_parallel_size: int = 1 + expert_tensor_parallel_size: int = None ddp_config: DDPConfig = field(default_factory=DDPConfig) optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) transformer_config: TransformerConfig = field(default_factory=TransformerConfig) @@ -52,7 +52,7 @@ class MegatronConfig: class Config: model: str = "Qwen/Qwen3-0.6B" # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it - num_nodes: int = 1 + num_nodes: int = 2 num_gpus_per_node: int = 4 mini_batch_size: int = 16 num_spare_gpus: int = 4 @@ -70,6 +70,15 @@ class Config: def main(): config = Config() # create placement group including spare gpus + + # need to set these env vars to avoid nccl error on nodes not supporting p2p + runtime_env = { + "env_vars": { + 
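+            # Disabling P2P and shared-memory transports forces NCCL onto slower
+            # fallback paths; these overrides are only needed on nodes without
+            # P2P support.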
"NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": "1", + } + } + ray.init(runtime_env=runtime_env) pg = placement_group( [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node + [{"GPU": 1, "CPU": 1}] * config.num_spare_gpus, diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py index c1789de..e6d330c 100644 --- a/megatron_ray_fault_tolerant/megatron_actor.py +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -127,6 +127,7 @@ def init_worker_process_group(self): use_sharp=False, context_parallel_size=self.megatron_config.context_parallel_size, nccl_communicator_config_path=None, + order="tp-pp-dp", ) self.set_seed(self.seed) self.world_size = dist.get_world_size() @@ -504,15 +505,15 @@ def load_checkpoint( def offload_to_cpu(self): self.all_buffer_sizes = offload_megatron_grads_to_cpu(self.actor_module) - offload_megatron_model_to_cpu(self.actor_module) - offload_megatron_optimizer(self.optimizer) + self.all_model_weights_and_sizes = offload_megatron_model_to_cpu(self.actor_module) + self.all_optimizer_weights_and_sizes = offload_megatron_optimizer(self.optimizer) torch.cuda.synchronize() torch.cuda.empty_cache() def backload_to_gpu(self): - load_megatron_grads_to_gpu(self.actor_module) - load_megatron_model_to_gpu(self.actor_module) - load_megatron_optimizer(self.optimizer) + load_megatron_grads_to_gpu(self.actor_module, self.all_buffer_sizes) + load_megatron_model_to_gpu(self.actor_module, self.all_model_weights_and_sizes) + load_megatron_optimizer(self.optimizer, self.all_optimizer_weights_and_sizes) torch.cuda.synchronize() torch.cuda.empty_cache() diff --git a/megatron_ray_fault_tolerant/megatron_utils.py b/megatron_ray_fault_tolerant/megatron_utils.py index 4a0f015..e1e3175 100644 --- a/megatron_ray_fault_tolerant/megatron_utils.py +++ b/megatron_ray_fault_tolerant/megatron_utils.py @@ -114,6 +114,7 @@ def offload_megatron_model_to_cpu(models): - fp32 main_parameter chunked in model and dp group - fp32 optimizer state chunked in model and dp group """ + # all_model_weights for model_chunk in models: if isinstance(model_chunk, DDP): model_chunk_all_buffers = [ diff --git a/megatron_ray_fault_tolerant/pyproject.toml b/megatron_ray_fault_tolerant/pyproject.toml index 51be3c3..fb43490 100644 --- a/megatron_ray_fault_tolerant/pyproject.toml +++ b/megatron_ray_fault_tolerant/pyproject.toml @@ -32,16 +32,6 @@ dependencies = [ "loguru", "jaxtyping", "s3fs", - # Make sure to change the flash attention source (under tool.uv.sources) above to a compatible version (<= 2.7.4.post1) for TransformerEngine==2.5.0 - # https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl - # For single node: build transformer-engine separately first, and uncomment the transformer-engine library import below - # uv pip install "torch==2.7.1" - # uv pip install "nvidia-cudnn-cu12>=9.3" - # export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" - # export CPATH="$CUDNN_PATH/include:${CPATH:-}" - # export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}" - # uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose - # "transformer-engine[pytorch]==2.5.0", "transformer-engine[pytorch]==2.7.0", "flash-attn==2.7.4.post1", "vllm==0.10.1.1", From 993215574a6e0f0c3eab64ee689e41ca4920d128 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Mon, 24 Nov 2025 11:49:12 -0800 Subject: 
[PATCH 09/15] Add megatron_ray_fault_tolerant example with comprehensive fault tolerance --- megatron_ray_fault_tolerant/.gitignore | 1 + .../.pre-commit-config.yaml | 20 + megatron_ray_fault_tolerant/Dockerfile | 34 + megatron_ray_fault_tolerant/README.md | 191 ++++ megatron_ray_fault_tolerant/dispatch.py | 299 ++++++ megatron_ray_fault_tolerant/file_io.py | 321 ++++++ megatron_ray_fault_tolerant/job.yaml | 45 + megatron_ray_fault_tolerant/main.py | 190 ++++ megatron_ray_fault_tolerant/megatron_actor.py | 934 ++++++++++++++++++ .../megatron_model_utils.py | 442 +++++++++ .../megatron_model_wrapper.py | 171 ++++ megatron_ray_fault_tolerant/megatron_utils.py | 465 +++++++++ megatron_ray_fault_tolerant/optimizer.py | 103 ++ megatron_ray_fault_tolerant/pyproject.toml | 98 ++ megatron_ray_fault_tolerant/run.sh | 1 + megatron_ray_fault_tolerant/training_batch.py | 371 +++++++ megatron_ray_fault_tolerant/utils.py | 286 ++++++ 17 files changed, 3972 insertions(+) create mode 100644 megatron_ray_fault_tolerant/.gitignore create mode 100644 megatron_ray_fault_tolerant/.pre-commit-config.yaml create mode 100644 megatron_ray_fault_tolerant/Dockerfile create mode 100644 megatron_ray_fault_tolerant/README.md create mode 100644 megatron_ray_fault_tolerant/dispatch.py create mode 100644 megatron_ray_fault_tolerant/file_io.py create mode 100644 megatron_ray_fault_tolerant/job.yaml create mode 100644 megatron_ray_fault_tolerant/main.py create mode 100644 megatron_ray_fault_tolerant/megatron_actor.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_utils.py create mode 100644 megatron_ray_fault_tolerant/megatron_model_wrapper.py create mode 100644 megatron_ray_fault_tolerant/megatron_utils.py create mode 100644 megatron_ray_fault_tolerant/optimizer.py create mode 100644 megatron_ray_fault_tolerant/pyproject.toml create mode 100755 megatron_ray_fault_tolerant/run.sh create mode 100644 megatron_ray_fault_tolerant/training_batch.py create mode 100644 megatron_ray_fault_tolerant/utils.py diff --git a/megatron_ray_fault_tolerant/.gitignore b/megatron_ray_fault_tolerant/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/megatron_ray_fault_tolerant/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/.pre-commit-config.yaml b/megatron_ray_fault_tolerant/.pre-commit-config.yaml new file mode 100644 index 0000000..5d51437 --- /dev/null +++ b/megatron_ray_fault_tolerant/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.9 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + exclude: (^(skyagent)/.*)$ + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + exclude: (^(skyagent)/.*)$ + + # Detect secrets and sensitive information + - repo: https://github.com/gitleaks/gitleaks + rev: v8.24.2 + hooks: + - id: gitleaks \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/Dockerfile b/megatron_ray_fault_tolerant/Dockerfile new file mode 100644 index 0000000..787c1c7 --- /dev/null +++ b/megatron_ray_fault_tolerant/Dockerfile @@ -0,0 +1,34 @@ +FROM anyscale/ray:2.51.0-slim-py312-cu128 + +RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev + +# the cuda compiler here is needed for deepspeed +RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \ + && sudo sh 
cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run + +RUN curl -LsSf https://astral.sh/uv/0.9.4/install.sh | sh +RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc + + +RUN sudo apt-get update \ + && sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \ + libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \ + && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \ + && sudo apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +# ---------- PyTorch + cuDNN + Transformer Engine ---------- +# PyTorch + cuDNN + Transformer Engine +RUN pip install --no-cache-dir "torch==2.7.1" "nvidia-cudnn-cu12>=9.3" && \ + CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" && \ + sudo mkdir -p /opt && sudo ln -sfn "$CUDNN_PATH" /opt/cudnn && \ + echo "/opt/cudnn/lib" | sudo tee /etc/ld.so.conf.d/cudnn.conf >/dev/null && sudo ldconfig + +ENV CUDNN_PATH=/opt/cudnn +ENV CPATH=${CUDNN_PATH}/include:${CPATH} +ENV LD_LIBRARY_PATH=${CUDNN_PATH}/lib:${LD_LIBRARY_PATH} + +RUN pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.5.0" +# -------------------- diff --git a/megatron_ray_fault_tolerant/README.md b/megatron_ray_fault_tolerant/README.md new file mode 100644 index 0000000..b7abecc --- /dev/null +++ b/megatron_ray_fault_tolerant/README.md @@ -0,0 +1,191 @@ +# Megatron + Ray Fault Tolerant Training + +This example implements PPO-style distributed training using Megatron and Ray with comprehensive fault tolerance capabilities. The system can automatically recover from actor failures during training by utilizing backup actors and re-initializing process groups. + +## Key Features + +### Fault Tolerance Mechanisms + +1. **Actor Health Monitoring**: Continuously monitors the health of distributed training actors +2. **Backup Actor Pool**: Pre-allocated backup actors ready to replace failed workers +3. **Automatic Recovery**: Seamlessly recovers from failures by: + - Detecting dead actors + - Destroying old process groups + - Replacing failed actors with backup actors + - Re-initializing process groups with new world size + - Reloading model and optimizer state from checkpoints + +4. **Distributed Checkpointing**: Implements efficient sharded checkpoint saving/loading using Megatron's distributed checkpointing +5. **Process Group Management**: Handles NCCL process group initialization, destruction, and re-initialization + +### Parallelism Support + +- **Data Parallelism (DP)**: Distributes training data across multiple GPUs +- **Tensor Parallelism (TP)**: Splits model tensors across GPUs +- **Pipeline Parallelism (PP)**: Distributes model layers across GPUs +- **Context Parallelism (CP)**: Enables sequence parallelism for long contexts + +### Advanced Training Features + +- **PPO Training**: Implements Proximal Policy Optimization with micro-batch accumulation +- **Mixed Precision**: Supports BF16 training for improved performance +- **Gradient Accumulation**: Handles micro-batches with automatic gradient accumulation +- **Distributed Optimizer**: Uses Megatron's distributed optimizer for memory efficiency + +## Architecture + +### Core Components + +1. 
**MegatronActor** (`megatron_actor.py`): + - Individual training actor wrapping Megatron models + - Handles model initialization, forward/backward passes, and checkpointing + - Supports dynamic process group re-initialization + +2. **MegatronActorGroup** (`megatron_actor.py`): + - Manages a group of distributed actors + - Implements fault recovery logic + - Coordinates distributed training operations + +3. **Dispatch System** (`dispatch.py`): + - **MeshDispatch**: Distributes data across the device mesh (DP, SP, TP, PP) + - **PassThroughDispatch**: Broadcasts same data/commands to all actors + - Handles data sharding and result collection + +4. **Training Batch** (`training_batch.py`): + - Defines input/output batch structures for PPO training + - Supports chunking and concatenation for distributed operations + +5. **Checkpoint I/O** (`file_io.py`): + - Cloud-aware file I/O supporting S3, GCS, and local storage + - Efficient checkpoint upload/download with parallel transfers + +## Getting Started + +### Quick Start + +```bash +uv run --isolated main.py +``` + +This will: +1. Create a placement group with workers and backup GPUs +2. Initialize the actor group and model +3. Run a training step +4. Save a checkpoint +5. Simulate a failure by killing actors +6. Recover from the failure using backup actors +7. Resume training after recovery + +### Configuration + +Edit the `Config` class in `main.py` to customize: + +```python +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" # HuggingFace model name + num_nodes: int = 1 + num_gpus_per_node: int = 4 + num_spare_gpus: int = 4 # Backup actors for fault tolerance + mini_batch_size: int = 16 + micro_train_batch_size_per_gpu: int = 2 + + # Megatron parallelism settings + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) +``` + +### Megatron Parallelism Configuration + +```python +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 # TP degree + pipeline_model_parallel_size: int = 1 # PP degree + context_parallel_size: int = 1 # CP degree + expert_model_parallel_size: int = 1 # For MoE models +``` + +## Fault Recovery Workflow + +1. **Training Phase**: + - Actors perform distributed training using Megatron + - Periodic checkpoints saved to cloud storage + +2. **Failure Detection**: + - System detects actor failures via health checks + - Identifies affected data parallel groups + +3. **Recovery Process**: + - Destroy old process groups on healthy actors + - Pop backup actors from the backup pool + - Insert backup actors at failed ranks + - Update world size and reassign ranks + - Re-initialize process groups with new configuration + - Reload model/optimizer state from checkpoint + +4. 
**Resume Training**: + - Continue training with recovered actor group + - No loss of training progress (from last checkpoint) + +## Advanced Usage + +### Custom Dispatch Types + +Register custom dispatch strategies: + +```python +from dispatch import register_dispatch_type, Dispatch + +class CustomDispatch(Dispatch): + # Implement dispatch, collect, and validate methods + pass + +register_dispatch_type("custom", CustomDispatch) +``` + +### CPU Offloading (Experimental) + +For faster recovery, offload model/optimizer state to CPU memory: + +```python +# Before failure +ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + +# After recovery, on healthy actors +ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu")) +``` + +## Dependencies + +See `pyproject.toml` for full dependency list. Key dependencies: +- Ray for distributed orchestration +- Megatron-Core for model parallelism +- PyTorch with CUDA support +- Transformers for model loading +- vLLM and related libraries + +## Running on Anyscale + +Submit the job using: + +```bash +anyscale job submit -f job.yaml +``` + +The job configuration in `job.yaml` specifies: +- Container image with dependencies +- GPU instance types (g6e.12xlarge with 4xL4) +- Resource limits and scaling +- Environment variables for NCCL configuration + +## Limitations and Future Work + +- Virtual pipeline parallelism not yet supported +- CPU offloading optimization in progress +- Async checkpoint saving planned for future releases + +## References + +- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +- [Ray Documentation](https://docs.ray.io/) +- [Anyscale Platform](https://docs.anyscale.com/) diff --git a/megatron_ray_fault_tolerant/dispatch.py b/megatron_ray_fault_tolerant/dispatch.py new file mode 100644 index 0000000..9949c48 --- /dev/null +++ b/megatron_ray_fault_tolerant/dispatch.py @@ -0,0 +1,299 @@ +"""Defines dispatch and collect logic for distributed training""" + +from dataclasses import dataclass +from ray.actor import ActorHandle +from typing import List, Tuple, Optional, Dict, Type, Any +import asyncio +from abc import ABC, abstractmethod +import ray +from ray import ObjectRef +from training_batch import TrainingInputBatch, TrainingOutputBatch +import inspect + + +@dataclass +class MeshRank: + """Represents a rank in the device mesh. + + This is a tuple of (DP, SP, TP, PP) ranks. + """ + + dp: int + sp: int + tp: int + pp: int + + world_size: int + dp_size: int + pp_size: int + + def is_collection_dp_rank(self) -> bool: + """Check if this rank is a DP rank to collect from + + This is the rank with (SP=0, TP=0, PP=pp_size-1) + + Note: double check this for ETP > 1 (but this is not a typically used case) + """ + return self.tp == 0 and self.pp == self.pp_size - 1 and self.sp == 0 + + def __str__(self) -> str: + return f"MeshRank(dp={self.dp}, sp={self.sp}, tp={self.tp}, pp={self.pp}, world_size={self.world_size}, dp_size={self.dp_size}, pp_size={self.pp_size})" + + def __repr__(self) -> str: + return self.__str__() + + +@dataclass +class ActorInfo: + """Actor information for distributed training. + + This includes the actor handle and the rank in the device mesh. 
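+
+    The mesh rank is what the dispatch strategies use to decide which data shard
+    an actor receives and whether its output is collected.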
+ """ + + handle: ActorHandle + rank: MeshRank + + +class Dispatch(ABC): + """Base class for dispatch types + + Dispatch types are responsible for: + - dispatching method calls to actors handling data sharding if necessary + - collecting results from actors and concatenating results if necessary + - validating arguments for dispatch + """ + + @classmethod + @abstractmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + """Dispatches method calls to the actors with data sharing if necessary.""" + pass + + @classmethod + @abstractmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors asynchronously in an asyncio-compatible way.""" + pass + + @classmethod + @abstractmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + """Collects results from the actors synchronously and returns a `TrainingOutputBatch`.""" + pass + + @classmethod + @abstractmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + """Validate and process arguments for dispatch. + + Returns: + Tuple of (args, kwargs) to be passed to dispatch + """ + pass + + +class MeshDispatch(Dispatch): + """Mesh dispatch type to dispatch data to a group of actors along the device mesh. + + Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. + The actor method should accept a single argument - the data batch. + + For data dispatch: + + * The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism. + * Each actor with the same DP rank processes the same data chunk in parallel. + + For data collection: + + * Data is collected only from the primary rank of each model/sequence parallel group. + * The primary rank is defined as the rank with (SP=0, TP=0, PP=0). + * The collected chunks are concatenated in order of DP rank to reconstruct the full data. + + Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1: + + * Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, + and all actors with DP rank 1 process the second chunk. + * Data collection: Only two actors contribute to the final output - the primary rank from each DP group: + (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order. 
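+
+    Note: `MeshRank.is_collection_dp_rank` selects the last pipeline stage
+    (PP = pp_size - 1) for collection; with PP size 1, as in the example above,
+    that coincides with PP=0.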
+ + """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch + ) -> List[ObjectRef]: + assert len(actor_infos) > 0, "actor_infos must be a non-empty list" + object_refs = [] + dp_size = actor_infos[0].rank.dp_size + assert ( + len(data) % dp_size == 0 + ), "data batch size must be divisible by dp_size, got {} and {}".format( + len(data), dp_size + ) + chunk_size = len(data) // dp_size + data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size) + + for actor_info in actor_infos: + # index into tensordict to get the correct data to send + data_to_send = data_chunks[actor_info.rank.dp] + object_refs.append(getattr(actor_info.handle, method).remote(data_to_send)) + return object_refs + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + assert len(actor_infos) == len( + object_refs + ), "`actor_infos` and `object_refs` must have the same length" + all_objects = ray.get(object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + # all should be none + assert all( + obj is None for obj in all_objects + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + sig = inspect.signature(cls.dispatch) + # pass dummy actor_infos and method_name + bound_args = sig.bind([], "dummy", *args, **kwargs) + bound_args.apply_defaults() + data = bound_args.arguments.get("data") + + # Check if there are any extra arguments + if len(bound_args.arguments) > 3: # data, actor_infos, method_name + # remove actor_infos and method_name - not added by user + bound_args.arguments.pop("actor_infos") + bound_args.arguments.pop("method") + raise ValueError( + f"MeshDispatch only accepts 'data' as an argument, got extra args: {bound_args.arguments}" + ) + + data = bound_args.arguments.get("data") + if not isinstance(data, TrainingInputBatch): + raise ValueError( + f"For MeshDispatch, `data` entry should be a `TrainingInput`, got {data}" + ) + args = (data,) + kwargs = {} + return args, kwargs + + +class PassThroughDispatch(Dispatch): + """PassThrough dispatch type to dispatch data to a group of actors without any sharding. + + This is useful for cases where we want to run the same method on all the actors. + Supports methods with any number of arguments. 
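+
+    Example:
+        # e.g. broadcast a checkpoint save to every actor in the group
+        refs = PassThroughDispatch.dispatch(actor_infos, "save_checkpoint", ckpt_dir=ckpt_dir)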
+ """ + + @classmethod + def dispatch( + cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs + ) -> List[ObjectRef]: + return [ + getattr(actor_info.handle, method).remote(*args, **kwargs) + for actor_info in actor_infos + ] + + @classmethod + async def async_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + all_objects = await asyncio.gather(*object_refs) + if len(all_objects) and all_objects[0] is not None: + return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects) + return + + @classmethod + def sync_collect( + cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef] + ) -> Optional[TrainingOutputBatch]: + data_batches = ray.get(object_refs) + if len(data_batches) > 0 and data_batches[0] is not None: + assert isinstance( + data_batches[0], TrainingOutputBatch + ), "data_batches must be a list of `TrainingOutputBatch` objects" + return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches) + # all should be none + assert all( + obj is None for obj in data_batches + ), "Got a mix of `None` and non-`None` objects" + return + + @classmethod + def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]: + # no validation needed just pass everything + return args, kwargs + + +class DispatchRegistry: + _registry: Dict[str, Type[Dispatch]] = { + "mesh": MeshDispatch, + "pass_through": PassThroughDispatch, + } + + @classmethod + def register(cls, name: str, dispatch_class: Type[Dispatch]) -> None: + """Register a new dispatch type.""" + assert issubclass(dispatch_class, Dispatch) + cls._registry[name] = dispatch_class + + @classmethod + def get(cls, name: str) -> Type[Dispatch]: + """Get a registered dispatch type.""" + if name not in cls._registry: + raise KeyError(f"Dispatch type '{name}' not registered") + return cls._registry[name] + + @classmethod + def list_registered(cls) -> Dict[str, Type[Dispatch]]: + """List all registered dispatch types.""" + return cls._registry + + +def register_dispatch_type(name: str, dispatch_class: Type) -> None: + DispatchRegistry.register(name, dispatch_class) + + +def concatenate_outputs_after_mesh_dispatch( + actor_infos: List[ActorInfo], data_batches: List[TrainingOutputBatch] +) -> TrainingOutputBatch: + """Concatenate data batches from different ranks after mesh dispatch. + + - Data is collected only from the primary DP rank. + - The collected chunks are concatenated in order of DP rank to reconstruct the full data. + """ + assert len(actor_infos) == len( + data_batches + ), "`actor_infos` and `data_batches` must have the same length" + shards = [] + # collect in-order + dp_rank_to_shard = {} + for actor_info, data_batch in zip(actor_infos, data_batches): + if actor_info.rank.is_collection_dp_rank(): + dp_rank = actor_info.rank.dp + dp_rank_to_shard[dp_rank] = data_batch + for i in range(actor_infos[0].rank.dp_size): + shards.append(dp_rank_to_shard[i]) + return TrainingOutputBatch.cat(shards) diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py new file mode 100644 index 0000000..932adbe --- /dev/null +++ b/megatron_ray_fault_tolerant/file_io.py @@ -0,0 +1,321 @@ +""" +File I/O utilities for handling both local filesystem and cloud storage (S3/GCS). 
+ +This module provides a unified interface for file operations that works with: +- Local filesystem paths +- S3 paths (s3://bucket/path) +- Google Cloud Storage paths (gs://bucket/path or gcs://bucket/path) + +Uses fsspec for cloud storage abstraction. +""" + +import os +import tempfile +from contextlib import contextmanager +import fsspec +from loguru import logger +from datetime import datetime, timezone, timedelta + +# Optional AWS deps (present when s3fs is installed) +try: + import botocore.session as _botocore_session + from botocore.exceptions import ClientError + + _HAS_BOTOCORE = True +except Exception: + _HAS_BOTOCORE = False + + class ClientError(Exception): # fallback type + pass + + +_S3_FS = None # type: ignore + + +def get_s3_fs(): + """Return a cached S3 filesystem instance, creating it once.""" + global _S3_FS + if _S3_FS is None: + _S3_FS = fsspec.filesystem("s3") + return _S3_FS + + +def s3_expiry_time(): + """Return botocore credential expiry (datetime in UTC) or None.""" + if not _HAS_BOTOCORE: + return None + try: + sess = _botocore_session.get_session() + creds = sess.get_credentials() + if not creds: + return None + return getattr(creds, "expiry_time", None) or getattr( + creds, "_expiry_time", None + ) + except Exception: + return None + + +def s3_refresh_if_expiring(fs) -> None: + """ + Simple refresh: + - If expiry exists and is within 300s (or past), refresh with fs.connect(refresh=True). + - Otherwise, do nothing. + """ + exp = s3_expiry_time() + if not exp: + return + now = datetime.now(timezone.utc) + if now >= exp - timedelta(seconds=300): + try: + fs.connect(refresh=True) # rebuild session + except Exception: + pass + + +def call_with_s3_retry(fs, fn, *args, **kwargs): + """ + Wrapper for calling an S3 method. If it fails with ExpiredToken, force refresh once and retry. + """ + try: + return fn(*args, **kwargs) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fn(*args, **kwargs) + raise + + +def is_cloud_path(path: str) -> bool: + """Check if the given path is a cloud storage path.""" + return path.startswith(("s3://", "gs://", "gcs://")) + + +def _get_filesystem(path: str): + """Get the appropriate filesystem for the given path.""" + if not is_cloud_path(path): + return fsspec.filesystem("file") + + proto = path.split("://", 1)[0] + if proto == "s3": + fs = get_s3_fs() + s3_refresh_if_expiring(fs) + return fs + return fsspec.filesystem(proto) + + +def open_file(path: str, mode: str = "rb"): + """Open a file using fsspec, works with both local and cloud paths.""" + if not is_cloud_path(path): + return fsspec.open(path, mode) + + fs = _get_filesystem(path) + norm = fs._strip_protocol(path) + try: + return fs.open(norm, mode) + except ClientError as e: + code = getattr(e, "response", {}).get("Error", {}).get("Code") + if code in { + "ExpiredToken", + "ExpiredTokenException", + "RequestExpired", + } and hasattr(fs, "connect"): + try: + fs.connect(refresh=True) + except Exception: + pass + return fs.open(norm, mode) + raise + + +def makedirs(path: str, exist_ok: bool = True) -> None: + """Create directories. 
Only applies to local filesystem paths.""" + if not is_cloud_path(path): + os.makedirs(path, exist_ok=exist_ok) + + +def exists(path: str) -> bool: + """Check if a file or directory exists.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.exists, path) + return fs.exists(path) + + +def isdir(path: str) -> bool: + """Check if path is a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.isdir, path) + return fs.isdir(path) + + +def list_dir(path: str) -> list[str]: + """List contents of a directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + return call_with_s3_retry(fs, fs.ls, path, detail=False) + return fs.ls(path, detail=False) + + +def remove(path: str) -> None: + """Remove a file or directory.""" + fs = _get_filesystem(path) + if is_cloud_path(path) and path.startswith("s3://"): + if call_with_s3_retry(fs, fs.isdir, path): + call_with_s3_retry(fs, fs.rm, path, recursive=True) + else: + call_with_s3_retry(fs, fs.rm, path) + return + if fs.isdir(path): + fs.rm(path, recursive=True) + else: + fs.rm(path) + + +def upload_directory(local_path: str, cloud_path: str) -> None: + """Upload a local directory to cloud storage. + + Uploads the contents of local_path to cloud_path, not the directory itself. + This ensures consistent behavior across all ranks by explicitly uploading each file. + """ + if not is_cloud_path(cloud_path): + raise ValueError(f"Destination must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + + # Normalize paths: ensure cloud_path ends with / to indicate directory + cloud_path_normalized = cloud_path.rstrip("/") + "/" + + # Walk the local directory and upload each file explicitly + # This ensures we upload contents, not the directory as a subdirectory + for root, dirs, files in os.walk(local_path): + for file in files: + local_file_path = os.path.join(root, file) + # Get relative path from local_path to maintain directory structure + rel_path = os.path.relpath(local_file_path, local_path) + # Construct remote path: cloud_path/rel_path + remote_file_path = cloud_path_normalized + rel_path + + if cloud_path.startswith("s3://"): + # For S3, strip protocol for fsspec operations + remote_file_path_stripped = fs._strip_protocol(remote_file_path) + # Ensure parent directories exist in S3 (fsspec handles this automatically) + call_with_s3_retry( + fs, fs.put, local_file_path, remote_file_path_stripped + ) + else: + fs.put(local_file_path, remote_file_path) + + logger.info(f"Uploaded contents of {local_path} to {cloud_path}") + + +def download_directory(cloud_path: str, local_path: str) -> None: + """Download a cloud directory to local storage.""" + if not is_cloud_path(cloud_path): + raise ValueError(f"Source must be a cloud path, got: {cloud_path}") + + fs = _get_filesystem(cloud_path) + cloud_path_normalized = cloud_path.rstrip("/") + "/" + os.makedirs(local_path, exist_ok=True) + + # List all files and download each one individually to download contents, not the folder + if cloud_path.startswith("s3://"): + remote_path_stripped = fs._strip_protocol(cloud_path_normalized) + all_files = call_with_s3_retry(fs, fs.find, remote_path_stripped, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(remote_path_stripped) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + 
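+            # Recreate the file's parent directories locally before downloading into them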
parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + call_with_s3_retry(fs, fs.get, remote_file, local_file_path) + else: + all_files = fs.find(cloud_path_normalized, detail=False) + for remote_file in all_files: + if remote_file.endswith("/"): + continue + rel_path = remote_file[len(cloud_path_normalized) :].lstrip("/") + local_file_path = os.path.join(local_path, rel_path) + parent_dir = os.path.dirname(local_file_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + fs.get(remote_file, local_file_path) + + logger.info(f"Downloaded {cloud_path} to {local_path}") + + +@contextmanager +def local_work_dir(output_path: str): + """ + Context manager that provides a local working directory. + + For local paths, returns the path directly. + For cloud paths, creates a temporary directory and uploads content at the end. + + Args: + output_path: The final destination path (local or cloud) + + Yields: + str: Local directory path to work with + + Example: + with local_work_dir("s3://bucket/model") as work_dir: + # Save files to work_dir + model.save_pretrained(work_dir) + # Files are automatically uploaded to s3://bucket/model at context exit + """ + if is_cloud_path(output_path): + with tempfile.TemporaryDirectory() as temp_dir: + try: + yield temp_dir + finally: + # Upload everything from temp_dir to cloud path + upload_directory(temp_dir, output_path) + logger.info(f"Uploaded directory contents to {output_path}") + else: + # For local paths, ensure directory exists and use it directly + makedirs(output_path, exist_ok=True) + yield output_path + + +@contextmanager +def local_read_dir(input_path: str): + """ + Context manager that provides a local directory with content from input_path. + + For local paths, returns the path directly. + For cloud paths, downloads content to a temporary directory. + + Args: + input_path: The source path (local or cloud) + + Yields: + str: Local directory path containing the content + + Example: + with local_read_dir("s3://bucket/model") as read_dir: + # Load files from read_dir + model = AutoModel.from_pretrained(read_dir) + """ + if is_cloud_path(input_path): + with tempfile.TemporaryDirectory() as temp_dir: + # Download everything from cloud path to temp_dir + download_directory(input_path, temp_dir) + logger.info(f"Downloaded directory contents from {input_path}") + yield temp_dir + else: + # For local paths, use directly (but check it exists) + if not exists(input_path): + raise FileNotFoundError(f"Path does not exist: {input_path}") + yield input_path diff --git a/megatron_ray_fault_tolerant/job.yaml b/megatron_ray_fault_tolerant/job.yaml new file mode 100644 index 0000000..f1c2de2 --- /dev/null +++ b/megatron_ray_fault_tolerant/job.yaml @@ -0,0 +1,45 @@ +# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. + +name: megatron-fault-tolerance + +# When empty, use the default image. This can be an Anyscale-provided base image +# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided +# that it meets certain specs), or you can build new images using the Anyscale +# image builder at https://console.anyscale-staging.com/v2/container-images. +# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 +containerfile: ./Dockerfile + +# When empty, Anyscale will auto-select the instance types. You can also specify +# minimum and maximum resources. +compute_config: + # Pin worker nodes to g6.xlarge (1xL4) so the vision workload lands on L4 GPUs. 
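+  # Two 4-GPU worker nodes cover the 4 training GPU bundles plus the 4 spare
+  # bundles that main.py requests for the backup actor group.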
+ worker_nodes: + - instance_type: g6e.12xlarge + min_nodes: 0 + max_nodes: 2 + min_resources: + CPU: 0 + GPU: 0 + max_resources: + CPU: 384 + GPU: 64 + +# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that +# will be the working directory for the job. The files in the directory will be +# automatically uploaded to the job environment in Anyscale. +working_dir: . + +# When empty, this uses the default Anyscale Cloud in your organization. +cloud: + +env_vars: + RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5" + NCCL_P2P_DISABLE: "1" + NCCL_SHM_DISABLE: "1" + +# The script to run in your job. You can also do "uv run main.py" if you have a +# pyproject.toml file in your working_dir. +entrypoint: uv run --isolated main.py + +# If there is an error, do not retry. +max_retries: 0 \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py new file mode 100644 index 0000000..b64b535 --- /dev/null +++ b/megatron_ray_fault_tolerant/main.py @@ -0,0 +1,190 @@ +import os +from dataclasses import dataclass, field +import ray +from typing import Optional, List +from megatron_actor import MegatronActorGroup +from ray.util.placement_group import placement_group + +import random +import time +from utils import get_test_training_batch, get_reordered_bundle_indices + + +@dataclass +class DDPConfig: + grad_reduce_in_fp32: bool = True + overlap_grad_reduce: bool = False + overlap_param_gather: bool = False + average_in_collective: bool = True + + +@dataclass +class OptimizerConfig: + lr: float = 1.0e-6 + weight_decay: float = 1e-2 + max_grad_norm: float = 1.0 + offload_after_step: bool = True + num_warmup_steps: int = 0 + scheduler: str = "constant_with_warmup" + + +@dataclass +class TransformerConfig: + recompute_granularity: Optional[str] = None + recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"]) + recompute_method: Optional[str] = None + recompute_num_layers: Optional[int] = None + + +@dataclass +class MegatronConfig: + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + context_parallel_size: int = 1 + expert_model_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + ddp_config: DDPConfig = field(default_factory=DDPConfig) + optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) + transformer_config: TransformerConfig = field(default_factory=TransformerConfig) + + +@dataclass +class Config: + model: str = "Qwen/Qwen3-0.6B" + # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it + num_nodes: int = 1 + num_gpus_per_node: int = 4 + mini_batch_size: int = 16 + num_spare_gpus: int = 4 + micro_train_batch_size_per_gpu: int = 2 + megatron_config: MegatronConfig = field(default_factory=MegatronConfig) + ckpt_dir: str = ( + os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt3/" + ) + # algorithm config + eps_clip_low: float = 0.2 + eps_clip_high: float = 0.2 + clip_ratio_c: float = 3.0 + + +def main(): + config = Config() + # create placement group including spare gpus + pg = placement_group( + [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node + + [{"GPU": 1, "CPU": 1}] * config.num_spare_gpus, + strategy="PACK", + ) + ray.get(pg.ready(), timeout=1200) + # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 + reordered_bundle_indices = get_reordered_bundle_indices(pg) + + 
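+    # The first num_nodes * num_gpus_per_node bundles host the training actors;
+    # the trailing num_spare_gpus bundles are held back for the backup group below.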
actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_nodes, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[:-config.num_spare_gpus], + ) + actor_group.initiate_worker_process_group() + ray.get(actor_group.async_init_model(config.model)) + + # potentially need some time for dependencies like transformer-engine-torch to build on worker nodes (this is something good to warm start...) + backup_actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_spare_gpus // config.num_gpus_per_node, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices[-config.num_spare_gpus:], + ) + # just place but don't initiate the worker process group for the backup actor group + # call a function to make sure the actors are placed + ray.get(backup_actor_group.async_run_method_no_dispatch("get_gpu_id")) + + # train on one batch + batch = get_test_training_batch(config.model, batch_size=32) + print("Starting training step 1...") + start_time = time.time() + ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) + print(f"Training step 1 took {time.time() - start_time:.2f} seconds") + + # save checkpoint + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") + + # TODO: add a cpu offload (or cpu save memory) call here + # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + + # TODO: run another training batch here and save results but don't save checkpoint + + # randomly kill an actor to simulate fault tolerance scenario + # TODO: go deeper into the actor code and throw an exception on a given node and catch it here + print("Simulating failure and recovery...") + start_time = time.time() + + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) + # get the whole dp group associated with the failed actor + dp_group_actors = [] + for actor_info in actor_group.actor_infos: + if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: + dp_group_actors.append(actor_info) + print( + f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." 
+ ) + for actor_info in dp_group_actors: + ray.kill(actor_info.handle) + + # Destroy process groups on all actors (including dead ones, which will fail gracefully) + print("Destroying old process groups...") + try: + ray.get( + actor_group.async_run_ray_method( + "pass_through", "destroy_worker_process_group" + ) + ) + except Exception as e: + print(f"Some actors failed during destroy (expected): {e}") + + for i, actor_info in enumerate(actor_group.actor_infos): + is_alive = actor_group._check_actor_alive(actor_info.handle) + print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") + + # Recover from failure: remove dead actors and re-initialize process group + print("Recovering from actor failure...") + actor_group.recover_from_failure(backup_actor_group) + + # load checkpoint on all actors + # TODO: improve the logic here + # we want to only call load checkpoint on the actors that are fresh + # on previously healthy actors we want to restore weights and optimizer state from cpu memory + # ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu"), actor_ids=[previously healthy actor ids]) + # only for new actors, we want to load the checkpoint + ray.get( + actor_group.async_run_ray_method( + "pass_through", "load_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Recovery took {time.time() - start_time:.2f} seconds") + + # TODO: check that results here are the same as before the failure when resuming from checkpoint + # Test that training still works after recovery + print("Testing training after recovery...") + batch_after_recovery = get_test_training_batch(config.model, batch_size=32) + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "ppo_train", batch_after_recovery + ) + ) + print(f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds") + print("Recovery successful! 
Training works with remaining actors.") + + +if __name__ == "__main__": + main() diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py new file mode 100644 index 0000000..c1789de --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -0,0 +1,934 @@ +import logging +import os +import random +import socket +from dataclasses import asdict +from tqdm import tqdm +from typing import Optional, Dict, Any, List +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +import ray +from ray import ObjectRef +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer +from loguru import logger + +# megatron +from megatron.bridge import AutoBridge +import megatron.core.parallel_state as mpu +from megatron.core import dist_checkpointing +from megatron.core.dist_checkpointing.strategies import base as ckpt_base +from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) + +# local imports +import file_io as io # local io module to support cloud storage for checkpointing +from training_batch import TrainingOutputBatch +from optimizer import ( + init_megatron_optim_config, + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, +) +from megatron_model_wrapper import MegatronModelWrapper +from megatron_utils import ( + offload_megatron_model_to_cpu, + offload_megatron_optimizer, + load_megatron_model_to_gpu, + load_megatron_optimizer, + offload_megatron_grads_to_cpu, + load_megatron_grads_to_gpu, +) +from utils import BatchIterator +from dispatch import DispatchRegistry, Dispatch, ActorInfo, MeshRank + + +@ray.remote(num_gpus=1) +class MegatronActor: + def __init__( + self, + world_size, + rank, + local_rank, + master_addr, + master_port, + megatron_config, + seed, + cfg, + ): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self._world_size = world_size + self._rank = rank + self._local_rank = local_rank + self._master_addr = master_addr if master_addr else self._get_current_node_ip() + self._master_port = master_port if master_port else self._get_free_port() + os.environ["MASTER_ADDR"] = self._master_addr + os.environ["MASTER_PORT"] = str(self._master_port) + os.environ["WORLD_SIZE"] = str(self._world_size) + os.environ["RANK"] = str(self._rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + os.environ["LOCAL_RANK"] = "0" + self.megatron_config = megatron_config + self.seed = seed + self.cfg = cfg + + def get_node_local_rank(self): + return self._local_rank + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if torch.cuda.device_count() > 0: + from megatron.core import tensor_parallel + + tensor_parallel.model_parallel_cuda_manual_seed(seed) + + def init_worker_process_group(self): + """Initialize worker process group and megatron model parallel.""" + # Destroy any existing 
process group first to ensure clean state + if torch.distributed.is_initialized(): + try: + torch.distributed.destroy_process_group() + except Exception: + pass # Ignore errors if already destroyed + + # Initialize process group using environment variables + torch.distributed.init_process_group(backend="nccl") + + local_rank = int(os.environ.get("LOCAL_RANK", "-1")) + if local_rank != -1: + torch.cuda.set_device(local_rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.megatron_config.tensor_model_parallel_size, + pipeline_model_parallel_size=self.megatron_config.pipeline_model_parallel_size, + expert_model_parallel_size=self.megatron_config.expert_model_parallel_size, + expert_tensor_parallel_size=self.megatron_config.expert_tensor_parallel_size, + use_sharp=False, + context_parallel_size=self.megatron_config.context_parallel_size, + nccl_communicator_config_path=None, + ) + self.set_seed(self.seed) + self.world_size = dist.get_world_size() + self.mesh_rank = MeshRank( + dp=mpu.get_data_parallel_rank(), + sp=mpu.get_context_parallel_rank(), + tp=mpu.get_tensor_model_parallel_rank(), + pp=mpu.get_pipeline_model_parallel_rank(), + world_size=self._world_size, + dp_size=mpu.get_data_parallel_world_size(), + pp_size=mpu.get_pipeline_model_parallel_world_size(), + ) + + def get_mesh_rank(self): + return self.mesh_rank + + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + def print(self, *msg): + """Print only on rank 0""" + if dist.get_rank() == 0: + logger.info(*msg) + + @staticmethod + def _get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + def get_ray_node_id(self): + return ray.get_runtime_context().get_node_id() + + @staticmethod + def get_rng_state(): + """Get current RNG state for reproducibility""" + rng_state = { + "cpu": torch.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + + # Only save CUDA RNG state if CUDA is available and being used + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + rng_state["cuda"] = torch.cuda.get_rng_state() + + return rng_state + + @staticmethod + def load_rng_state(rng_state): + """Load RNG state for reproducibility""" + torch.set_rng_state(rng_state["cpu"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) + + # Only restore CUDA RNG state if it was saved and CUDA is available + if ( + "cuda" in rng_state + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + torch.cuda.set_rng_state(rng_state["cuda"]) + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + def destroy_worker_process_group(self): + mpu.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # Clear stale env vars + for env_var in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]: + if env_var in os.environ: + del os.environ[env_var] + + def reinit_model_after_recovery(self): + """Re-initialize model and optimizer after process group recovery. + + This is needed because the model and optimizer were created with the old + process group and still have references to old NCCL communicators. + + We need to fully reinitialize the provider and model to ensure they use + the new process group. 
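+
+        Model weights and optimizer state are restored separately after this
+        call (in this example via `load_checkpoint` on the saved checkpoint).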
+ """ + if not hasattr(self, "_model_path") or self._model_path is None: + # Fall back to cfg.model if _model_path not set + if hasattr(self.cfg, "model"): + model_path = self.cfg.model + else: + logger.warning("No model path found, cannot re-initialize model") + return + else: + model_path = self._model_path + + num_training_steps = getattr(self, "_num_training_steps", 1e9) + + logger.info("Re-initializing model components after process group recovery...") + + # Re-initialize the bridge and provider with the new process group + # This ensures all NCCL communicators are created fresh + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # Recreate the DDP-wrapped module with the new process group + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + # Recreate optimizer with the new process group + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + # Recreate scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # Recreate model wrapper + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # Re-normalize mini batch size with new world size + self._normalize_mini_batch_size() + + logger.info("Model components re-initialized successfully") + + def update_world_size(self, new_world_size: int): + """Update the world_size stored in the actor.""" + self._world_size = new_world_size + os.environ["WORLD_SIZE"] = str(new_world_size) + + def update_rank(self, new_rank: int): + """Update the rank stored in the actor.""" + self._rank = new_rank + os.environ["RANK"] = str(new_rank) + + def update_master_addr_port(self, master_addr: str, master_port: int): + """Update the master address and port for process group initialization.""" + self._master_addr = master_addr + self._master_port = master_port + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + def _normalize_mini_batch_size(self): + """ + Normalize mini batch sizes to per-gpu mini batch sizes. + """ + if not hasattr(self, "mesh_rank") or self.mesh_rank is None: + raise RuntimeError( + "mesh_rank must be initialized before calling _normalize_mini_batch_size()" + ) + + dp_size = self.mesh_rank.dp_size + self.policy_mini_batch_size_per_gpu = self.cfg.mini_batch_size // dp_size + + def ppo_train(self, train_data) -> "TrainingOutputBatch": + """ + Overrides `PolicyWorkerBase.ppo_train` for megatron. + + Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the + worker MegatronModelWrapper.forward_backward_mini_batch method. 
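+
+        Each mini batch of `policy_mini_batch_size_per_gpu` samples is accumulated as
+        micro batches of `micro_train_batch_size_per_gpu`, and one optimizer step is
+        taken per mini batch.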
+ """ + dataloader = BatchIterator( + train_data, + sample_batch_size=self.cfg.micro_train_batch_size_per_gpu, + drop_last=False, + ) + + micro_batches_per_mini_batch = ( + self.policy_mini_batch_size_per_gpu + // self.cfg.micro_train_batch_size_per_gpu + ) + + self.optimizer.zero_grad() + pbar = tqdm( + dataloader, + desc="ppo train", + disable=not dist.get_rank() == 0, + ) + + micro_buffer = [] + for local_step, experience in enumerate(pbar): + experience.to_device(torch.cuda.current_device()) + sequences = experience.sequences + attention_mask = experience.attention_mask + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + } + ) + + if len(micro_buffer) == micro_batches_per_mini_batch: + # run mini-batch forward-backward and then one optimizer step + self.model.train() + for chunk in self.actor_module: + # if use distributed optimizer, zero grad buffer will be handled by optimizer + chunk.zero_grad_buffer() + seq_len = micro_buffer[0]["sequences"].shape[1] + micro_bsz = micro_buffer[0]["sequences"].shape[0] + + self.model.forward_backward_mini_batch( + micro_batches=micro_buffer, + seq_len=seq_len, + micro_batch_size=micro_bsz, + ) + + _, grad_norm, _ = self.optimizer.step() + self.scheduler.step(1) + self.optimizer.zero_grad() + + torch.distributed.barrier() + + def save_checkpoint(self, ckpt_dir: str): + # Extract base model. + model: List[nn.Module] = self.model.actor_module + optimizer = self.optimizer + scheduler = self.scheduler + node_local_rank = self.get_node_local_rank() + assert ( + len(model) == 1 + ), "Megatron virtual pipeline parallel is not yet supported" + model = model[0] + if hasattr(model, "module"): + model = model.module + + # Create checkpoint directory if it doesn't exist. + if node_local_rank == 0: + io.makedirs(ckpt_dir, exist_ok=True) + + # All ranks wait for the checkpoint directory to be created before saving. + dist.barrier() + + # Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler. + sharded_state_dict = {} + model_sharded_state_dict = model.sharded_state_dict() + sharded_state_dict["model"] = model_sharded_state_dict + if optimizer: + sharded_state_dict["optimizer"] = optimizer.sharded_state_dict( + model_sharded_state_dict + ) + if scheduler: + sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + + # Save RNG state. + sharded_state_dict["rng"] = self.get_rng_state() + + # Save the checkpoint across ranks in parallel. 
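+        # FullyParallelSaveStrategyWrapper distributes shard writes across the
+        # data-parallel (with context-parallel) group rather than funneling them
+        # through a single rank.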
+ save_strategy = get_default_save_sharded_strategy("torch_dist") + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + + with io.local_work_dir(ckpt_dir) as work_dir: + # synchronous checkpointing for now + async_save_request = dist_checkpointing.save( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=work_dir, + sharded_strategy=save_strategy, + async_sharded_save=False, + validate_access_integrity=True, + ) + assert ( + async_save_request is None + ), "Async save is not yet supported for Megatron" + + dist.barrier() + ckpt_base.async_calls.close() + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + self.print(f"Checkpoint successfully saved to {ckpt_dir}") + + def load_checkpoint( + self, + ckpt_dir: str, + load_module_strict: bool = True, + load_optimizer_states: bool = True, + load_lr_scheduler_states: bool = True, + ): + if not ckpt_dir or not io.exists(ckpt_dir): + raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}") + + # Extract base model. + model: List[nn.Module] = self.model.actor_module + optimizer = self.optimizer + scheduler = self.scheduler + assert ( + len(model) == 1 + ), "Megatron virtual pipeline parallel is not yet supported" + unwrapped_model = model[0] + if hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + + # Extract sharded state dicts. + sharded_state_dict = {} + model_sharded_state_dict = unwrapped_model.sharded_state_dict() + sharded_state_dict["model"] = model_sharded_state_dict + if optimizer and load_optimizer_states: + sharded_state_dict["optimizer"] = optimizer.sharded_state_dict( + model_sharded_state_dict + ) + if scheduler and load_lr_scheduler_states: + sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + + # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory + # this should be improved to download only the relevant shards for this actor to load + with io.local_read_dir(ckpt_dir) as read_dir: + # Load the checkpoint in parallel. + load_strategy = get_default_load_sharded_strategy(read_dir) + load_strategy = FullyParallelLoadStrategyWrapper( + load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=read_dir, + sharded_strategy=load_strategy, + ) + + # Load the model, optimizer, and scheduler state dicts. + assert ( + "model" in state_dict + ), f"Model state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + model[0].load_state_dict(state_dict["model"], strict=load_module_strict) + self.print("Loaded model state dict.") + + if optimizer and load_optimizer_states: + assert ( + "optimizer" in state_dict + ), f"Optimizer state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + optimizer.load_state_dict(state_dict["optimizer"]) + self.print("Loaded optimizer state dict.") + + if scheduler and load_lr_scheduler_states: + assert ( + "lr_scheduler" in state_dict + ), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" + scheduler.load_state_dict(state_dict["lr_scheduler"]) + self.print("Loaded LR scheduler state dict.") + + # Load RNG state, if present. 
+ if "rng" in state_dict: + self.load_rng_state(state_dict["rng"]) + + return ckpt_dir, {} + + def offload_to_cpu(self): + self.all_buffer_sizes = offload_megatron_grads_to_cpu(self.actor_module) + offload_megatron_model_to_cpu(self.actor_module) + offload_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def backload_to_gpu(self): + load_megatron_grads_to_gpu(self.actor_module) + load_megatron_model_to_gpu(self.actor_module) + load_megatron_optimizer(self.optimizer) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # model init and bridge from huggingface methods: + def init_configs( + self, + model_path, + megatron_config, + transformer_config, + bf16=True, + flash_attn=True, + ): + """ + Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer + """ + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend + transformer_config = asdict(transformer_config) + transformer_config["attention_backend"] = "flash" if flash_attn else "fused" + + bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) + provider = bridge.to_megatron_provider() + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = ( + megatron_config.pipeline_model_parallel_size + ) + provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 + provider.context_parallel_size = megatron_config.context_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = ( + megatron_config.expert_tensor_parallel_size + ) + provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 + provider.attention_backend = "flash" if flash_attn else "fused" + provider.variable_seq_lengths = True + provider.masked_softmax_fusion = True + provider.moe_token_dispatcher_type = "alltoall" + + for k, v in transformer_config.items(): + setattr(provider, k, v) + provider.finalize() + + self.provider = provider + self.bridge = bridge + self.tokenizer = tokenizer + + def make_megatron_module( + self, + wrap_with_ddp: bool = True, + ddp_config: Optional[Dict[str, Any]] = None, + bf16: bool = True, + ) -> List[nn.Module]: + """ + Creates a megatron GPTModel (optionally DDP wrapped) using the bridge. + """ + from megatron.core.distributed.distributed_data_parallel_config import ( + DistributedDataParallelConfig, + ) + + default_ddp_config = DistributedDataParallelConfig() + if wrap_with_ddp: + default_ddp_config.use_distributed_optimizer = True + if ddp_config is not None: + for k, v in ddp_config.items(): + setattr(default_ddp_config, k, v) + model = self.provider.provide_distributed_model( + ddp_config=default_ddp_config, wrap_with_ddp=wrap_with_ddp, bf16=bf16 + ) + return model + + def init_model(self, model_path, num_training_steps: int = 1e9): + """ + Initialize the model, optimizer, and scheduler for the policy worker. 
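+
+        Builds the Megatron-Bridge provider from the HF config, wraps the model in
+        DDP, then creates the Megatron optimizer, LR scheduler, and the
+        `MegatronModelWrapper` used by `ppo_train`.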
+ """ + # Store model path for potential recovery + self._model_path = model_path + self._num_training_steps = num_training_steps + + # initialize the bridge and provider objects + self.init_configs( + model_path, + megatron_config=self.cfg.megatron_config, + transformer_config=self.cfg.megatron_config.transformer_config, + bf16=True, + flash_attn=True, + ) + + # wrap with DDP for training + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=asdict(self.cfg.megatron_config.ddp_config), + bf16=True, + ) + + if self._local_rank == 0 and not os.path.exists( + model_path + ): # if not local path, try downloading model weights from huggingface + snapshot_download(model_path) # will be no-op if already downloaded + torch.distributed.barrier() + + # create optimizer + optim_config = init_megatron_optim_config( + asdict(self.cfg.megatron_config.optimizer_config) + ) + self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) + + self._normalize_mini_batch_size() + + # create scheduler + self.scheduler = get_megatron_optimizer_param_scheduler( + optimizer=self.optimizer, + config=asdict(self.cfg.megatron_config.optimizer_config), + num_training_steps=num_training_steps, + ) + + # create worker model + self.model = MegatronModelWrapper( + config=self.cfg, + actor_module=self.actor_module, + actor_optimizer=self.optimizer, + ) + + # NOTE: Set Megatron dist checkpoint async backend to persistent to avoid `os.fork()`-ing + # short-lived background workers, which does not work well with Ray. + ckpt_base.async_calls = AsyncCallsQueue(persistent=True) + + +class MegatronActorGroup: + """ + A group of distributed megatron actors + Functions start with 'async' should return list of object refs + + Args: + cfg: config object for workers + num_nodes (int): Number of nodes for this actor group. + num_gpus_per_node (int): Number of gpus for this actor group. + pg (PlacementGroup, optional): Placement group to schedule actor on. + If none, create new placement group automatically. Defaults to None. + num_gpus_per_actor (float, optional): Number of gpus allocated for each actor. + If < 1.0, multiple models can share same gpu. Defaults to 1. + """ + + def __init__( + self, + cfg, + num_nodes, + num_gpus_per_node, + pg: PlacementGroup, + bundle_indices: List[int], + num_gpus_per_actor: float = 1.0, + resources: Optional[Dict[str, float]] = None, + num_resources_per_node: Optional[int] = None, + ) -> None: + self.cfg = cfg + self._num_nodes = num_nodes + self._num_gpus_per_node = num_gpus_per_node + + # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html + self._resources = resources + self._num_resources_per_node = num_resources_per_node + + self._initiate_actors(pg, num_gpus_per_actor, bundle_indices) + + def _initiate_actors( + self, + pg: Optional[PlacementGroup], + num_gpus_per_actor: float, + bundle_indices: List[int], + ): + """Initialize Ray actors in the worker group. + + Args: + pg: The placement group for the worker group + num_gpus_per_actor: The number of gpus to allocate per actor. 
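+            bundle_indices: Placement group bundle indices to schedule actors on;
+                rank i is placed on bundle_indices[i].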
+ """ + world_size = self._num_nodes * self._num_gpus_per_node + assert pg is not None, "placement group must be provided to MegatronActorGroup" + pg_data = placement_group_table(pg) + assert ( + len(pg_data["bundles"]) >= world_size + ), "the number of bundles in the shared placement group must be greater than or equal to the world size" + + # place master actor on the + master_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[0], + ), + ).remote( + world_size=world_size, + rank=0, + local_rank=0, + master_addr=None, + master_port=None, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + + self._actor_handlers = [master_actor] + # Create worker actors + if world_size > 1: + master_addr, master_port = ray.get( + master_actor.get_master_addr_port.remote() + ) + for rank in range(1, world_size): + local_rank = rank % self._num_gpus_per_node + + worker_actor = MegatronActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_indices[rank], + ), + ).remote( + world_size=world_size, + rank=rank, + local_rank=local_rank, + master_addr=master_addr, + master_port=master_port, + megatron_config=self.cfg.megatron_config, + seed=42, + cfg=self.cfg, + ) + self._actor_handlers.append(worker_actor) + + def initiate_worker_process_group(self): + # Initialize process group + logger.info("Initializing process group for RayActorGroup") + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + logger.info("Initialized process group for RayActorGroup") + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) + + def async_init_model( + self, + *args, + **kwargs, + ) -> List[ObjectRef]: + """Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers. + + Returns: + A list of ray object refs. + """ + return [ + actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers + ] + + def async_run_ray_method( + self, dispatch_type: str, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors using specified dispatch type asynchronously. 
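+        With "mesh", the input batch is sharded across data-parallel ranks; with
+        "pass_through", the same arguments are sent to every actor.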
+ + Args: + dispatch_type: Type of dispatch to use ("mesh" or "pass_through") + method_name: Name of the method to call on actors + *args: Positional arguments to pass to the method + **kwargs: Keyword arguments to pass to the method + + Returns: + List of object references + """ + dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type) + # validate the dispatch args to be sent to `.dispatch` + args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs) + + # Dispatch the method call + object_refs = dispatch_class.dispatch( + self.actor_infos, method_name, *args, **kwargs + ) + return object_refs + + def async_run_method_no_dispatch( + self, method_name: str, *args, **kwargs + ) -> List[ObjectRef]: + """Run a method on all actors without dispatching.""" + return [ + getattr(handle, method_name).remote(*args, **kwargs) + for handle in self._actor_handlers + ] + + def _check_actor_alive(self, actor_handle) -> bool: + """Check if an actor is still alive by attempting to call a simple method.""" + try: + # Try to get a simple attribute or call a simple method with timeout + ray.get(actor_handle.get_mesh_rank.remote(), timeout=10) + return True + except Exception: + return False + + def recover_from_failure( + self, backup_actor_group: Optional["MegatronActorGroup"] = None + ): + """Recover from actor failures by removing dead actors and re-initializing process group.""" + logger.info("Starting recovery from actor failure...") + + # Filter out dead actors - both actor_infos and actor_handlers should be in sync + alive_actor_handlers = [] + num_dead_actors = 0 + dead_actor_ranks = [] + + for i, (actor_info, actor_handle) in enumerate( + zip(self.actor_infos, self._actor_handlers) + ): + if self._check_actor_alive(actor_info.handle): + alive_actor_handlers.append(actor_handle) + else: + logger.warning(f"Actor {i} is dead, removing from group") + num_dead_actors += 1 + dead_actor_ranks.append(i) + + if len(alive_actor_handlers) == 0: + raise RuntimeError("All actors are dead, cannot recover") + + if len(alive_actor_handlers) == len(self._actor_handlers): + logger.info("All actors are alive, no recovery needed") + return + + logger.info( + f"Recovering with {len(alive_actor_handlers)}/{len(self._actor_handlers)} actors" + ) + + self._actor_handlers = alive_actor_handlers + + # Destroy existing process groups on alive actors first + logger.info("Destroying old process groups...") + try: + ray.get( + [ + actor.destroy_worker_process_group.remote() + for actor in self._actor_handlers + ] + ) + except Exception as e: + logger.warning( + f"Some errors during process group destruction (may be expected): {e}" + ) + + # if backup actor group is provided, we pop idle actors from the backup actor group and insert them into the current actor group + if backup_actor_group is not None: + logger.info( + f"Popping {num_dead_actors} idle actors from backup actor group" + ) + idle_actor_handles = [ + backup_actor_group._actor_handlers.pop() for _ in range(num_dead_actors) + ] + # let's assume for now that the dead actors are contiguous in the actor group, so we insert the idle actors at the rank of the first dead actor + rank_to_insert = min(dead_actor_ranks) + logger.info(f"Inserting idle actors at rank {rank_to_insert}") + self._actor_handlers = ( + self._actor_handlers[:rank_to_insert] + + idle_actor_handles + + self._actor_handlers[rank_to_insert:] + ) + + # Re-initialize process group with remaining actors + # Update world_size and ranks to match the number of alive actors + new_world_size = 
len(self._actor_handlers) + + # Update world_size and reassign ranks sequentially (0, 1, 2, ...) + logger.info(f"Updating world_size to {new_world_size} and reassigning ranks...") + update_tasks = [] + for new_rank, actor in enumerate(self._actor_handlers): + update_tasks.append(actor.update_world_size.remote(new_world_size)) + update_tasks.append(actor.update_rank.remote(new_rank)) + ray.get(update_tasks) + + # get master address and a new free port for the new process group + master_addr, _ = ray.get(self._actor_handlers[0].get_master_addr_port.remote()) + master_port = ray.get(self._actor_handlers[0]._get_free_port.remote()) + logger.info(f"Using master_addr={master_addr}, master_port={master_port}") + + # Update master address/port in all actors + ray.get( + [ + actor.update_master_addr_port.remote(master_addr, master_port) + for actor in self._actor_handlers + ] + ) + + # Re-initialize process groups with new world_size and ranks + logger.info( + f"Re-initializing process group with world_size={new_world_size}..." + ) + ray.get( + [actor.init_worker_process_group.remote() for actor in self._actor_handlers] + ) + + # Re-initialize model and optimizer with the new process group + # This is critical because they were created with the old process group + logger.info("Re-initializing model and optimizer with new process group...") + ray.get( + [ + actor.reinit_model_after_recovery.remote() + for actor in self._actor_handlers + ] + ) + + # Update actor_infos with new mesh ranks + self.actor_infos = [ + ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) + for actor in self._actor_handlers + ] + logger.info( + f"Recovery complete. New mesh ranks: {[actor_info.rank for actor_info in self.actor_infos]}" + ) diff --git a/megatron_ray_fault_tolerant/megatron_model_utils.py b/megatron_ray_fault_tolerant/megatron_model_utils.py new file mode 100644 index 0000000..bc5be4a --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_utils.py @@ -0,0 +1,442 @@ +# Utils ported from NeMo-Aligner by way of NeMo-RL +# https://github.com/NVIDIA-NeMo/RL/blob/9301d36cbf847212430b84a27cfe6990f773b7cf/nemo_rl/distributed/model_utils.py#L4 +# The original copyright is reproduced below: + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + + +@torch.no_grad() +def _compute_distributed_log_softmax( + vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup +) -> torch.Tensor: + """Compute a stable distributed log softmax across tensor parallel workers. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_length, vocab_size//TP] + where TP is the tensor parallel size. + group (torch.distributed.ProcessGroup): Process group for the all-reduce operations. 
+ + Returns: + torch.Tensor: Log softmax output with the same shape as input, but values represent + log probabilities normalized across the full vocabulary dimension. + """ + logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True) + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max + + sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float() + + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype) + + +class DistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group) + softmax_output = log_probs.exp() + + log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) + log_probs[target_mask] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(softmax_output, target_mask, masked_target) + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + softmax, target_mask, masked_target = ctx.saved_tensors + + if softmax.ndim == 3: + B, S, V = softmax.shape + + # skip `torch.nn.functional.one_hot` + row = ( + torch.arange(B, device=softmax.device) + .view(-1, 1) + .expand(-1, S) + .reshape(-1) + ) + col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) + flat_idx = (row * S + col) * V + + flat_chosen = flat_idx.masked_select( + ~target_mask.reshape(-1) + ) + masked_target.masked_select(~target_mask) + + # `neg` is zero-copy + grad_input = softmax.neg() + grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) + + grad_output_selected = grad_output.masked_select(~target_mask) + grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) + else: + V = softmax.size(-1) + is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot( + masked_target, num_classes=V + ) + grad_input = is_chosen.float().sub_(softmax) + grad_input.mul_(grad_output.unsqueeze(-1)) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +class 
ChunkedDistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + + The log probabilities computation is chunked in the sequence dimension + to mitigate GPU OOM (especially during backward pass). + In addition, logits casting from float16 or bfloat16 -> float32 is performed + inside the chunk loop to avoid materializing a whole float32 logits tensor. + + Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_log_probs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + + log_probs = torch.gather( + log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) + ).squeeze(-1) + log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + all_log_probs.append(log_probs) + + log_probs = torch.cat(all_log_probs, dim=1) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(vocab_parallel_logits, target_mask, masked_target) + ctx.chunk_size = chunk_size + ctx.tp_group = tp_group + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + vocab_parallel_logits, target_mask, masked_target = ctx.saved_tensors + chunk_size = ctx.chunk_size + tp_group = ctx.tp_group + + partition_vocab_size = int(vocab_parallel_logits.shape[-1]) + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + + all_grad_input = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] + logits = logits.to(dtype=torch.float32) + + softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + softmax_output = softmax_output.exp() + + # 1 if it's the chosen log prob, 0 otherwise + is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( + -1 + ) * torch.nn.functional.one_hot( + masked_target[:, chunk_start:chunk_end], + num_classes=partition_vocab_size, + ) + + grad_input = is_chosen.float().sub_(softmax_output) + + grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) + + all_grad_input.append(grad_input) + + 
grad_input = torch.cat(all_grad_input, dim=1) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + +def from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, +) -> torch.Tensor: + """Get log probabilities from TP+CP sharded vocab logits. + + Args: + vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] + where TP is the tensor parallel size. + target (torch.Tensor): Target token indices with shape [batch_size, seq_len]. + NOTE: Must be the unmodified targets as this function will shift them internally. + vocab_start_index (int): Starting vocabulary index for this worker's partition. + vocab_end_index (int): Ending vocabulary index for this worker's partition. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. + inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + + Returns: + torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. + The sequence dimension is reduced by 1 due to the target shifting. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + """ + target = target.roll(shifts=-1, dims=-1) + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + pad_len = 0 + # if cp_size > 1: + # Pad the targets to local size * cp_size + pad_len = vocab_parallel_logits.shape[1] * cp_size - target.shape[1] + if pad_len > 0: + target = torch.nn.functional.pad(target, (0, pad_len), value=0) + + # Shard the targets by context parallelism + cp_rank = torch.distributed.get_rank(cp_group) + target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) + + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() + + if cp_size > 1: + # we need to gather the logits by context parallelism + logprobs = allgather_cp_sharded_tensor( + logprobs, cp_group, seq_dim=1 + ) # , unpadded_seqlen=target.shape[1]) + + if pad_len > 0: + logprobs = logprobs[:, :-pad_len] + + return logprobs[:, :-1] + + +def _get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1, +) -> torch.Tensor: + """Get tokens on this context parallelism rank. + + Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. 
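+    For example, with cp_size=2 the (padded) sequence dimension is split into 4
+    chunks: rank 0 keeps chunks (0, 3) and rank 1 keeps chunks (1, 2), which balances
+    the causal-attention workload across ranks.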
+ + Args: + input_ids: Input token IDs [seq_length, ] + cp_rank: Context parallelism rank + cp_size: Context parallelism size + + Returns: + Tokens on this context parallelism rank [1, seq_length // cp_size] + """ + if cp_size == 1: + return input_ids + + # load balance for causal attention + shard_size = input_ids.shape[seq_dim] // (cp_size * 2) + shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1) + + # Create slices for each dimension + slices = [slice(None)] * input_ids.dim() + ids_chunks = [] + + for ind in shard_inds: + slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) + ids_chunks.append(input_ids[slices]) + + ids = torch.cat(ids_chunks, dim=seq_dim) + return ids + + +def allgather_cp_sharded_tensor( + tensor, cp_group, seq_dim=1 +): # , unpadded_seqlen=None): + return AllGatherCPTensor.apply(tensor, cp_group, seq_dim) # , unpadded_seqlen) + + +class AllGatherCPTensor(torch.autograd.Function): + def forward( + ctx, tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1 + ): # , unpadded_seqlen: Optional[int] = None): + cp_size = torch.distributed.get_world_size(cp_group) + cp_rank_chunks = [] + for _ in range(cp_size): + cp_rank_chunks.append(torch.empty_like(tensor)) + + torch.distributed.all_gather( + tensor_list=cp_rank_chunks, tensor=tensor, group=cp_group + ) + + # undo the CP load balancing chunking + tensor_chunks = [] + for logit_chunk in cp_rank_chunks: + tensor_chunks.extend(torch.chunk(logit_chunk, chunks=2, dim=seq_dim)) + + chunk_indices = [] + for cp_rank in range(cp_size): + chunk_indices.append(cp_rank) + chunk_indices.append(2 * cp_size - cp_rank - 1) + + chunks_and_indices = list(zip(tensor_chunks, chunk_indices)) + chunks_and_indices = sorted(chunks_and_indices, key=lambda tup: tup[1]) + ret_tensor = [chunk for chunk, _ in chunks_and_indices] + ret_tensor = torch.cat(ret_tensor, dim=seq_dim) + + ctx.seq_dim = seq_dim + ctx.cp_group = cp_group + # ctx.unpadded_seqlen = unpadded_seqlen + + return ret_tensor + + def backward(ctx, grad_output): + cp_size = torch.distributed.get_world_size(ctx.cp_group) + cp_rank = torch.distributed.get_rank(ctx.cp_group) + torch.distributed.all_reduce(grad_output, group=ctx.cp_group) + + # chunk the seqdim in 2*cp chunks, and select with a CP load balanced indexing + seq_dim = ctx.seq_dim + # if ctx.unpadded_seqlen is not None: + # # Zero out grad_output along the seq_dim after unpadded_seqlen + # slicer = [slice(None)] * grad_output.dim() + # slicer[seq_dim] = slice(ctx.unpadded_seqlen, None) + # grad_output[tuple(slicer)] = 0 + + grad_output = grad_output.view( + *grad_output.shape[0:seq_dim], + 2 * cp_size, + grad_output.shape[seq_dim] // (2 * cp_size), + *grad_output.shape[(seq_dim + 1) :], + ) + + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + + grad_input = grad_output.index_select(seq_dim, index) + grad_input = grad_input.view( + *grad_input.shape[0:seq_dim], -1, *grad_input.shape[(seq_dim + 2) :] + ) + + return grad_input, None, None # , None diff --git a/megatron_ray_fault_tolerant/megatron_model_wrapper.py b/megatron_ray_fault_tolerant/megatron_model_wrapper.py new file mode 100644 index 0000000..07e885d --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_model_wrapper.py @@ -0,0 +1,171 @@ +from typing import Optional, List +from functools import partial +import torch +import torch.nn as nn + +from megatron.core.pipeline_parallel import get_forward_backward_func +import megatron.core.parallel_state as mpu +from 
megatron.core.distributed import finalize_model_grads + +from megatron_model_utils import from_parallel_logits_to_logprobs +from megatron_utils import ( + get_model_config, + make_batch_generator, + preprocess_packed_seqs, + postprocess_packed_seqs, +) +from utils import ppo_policy_loss + + +class MegatronModelWrapper: + def __init__( + self, + config, + actor_module: List[nn.Module], + actor_optimizer: Optional[torch.optim.Optimizer] = None, + ): + self.cfg = config + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + + config = get_model_config(self.actor_module[0]) + # This is set to None by default: https://github.com/NVIDIA/Megatron-LM/blob/07b22a05136a3cb08ece05f7de38cf6aeeb165fb/megatron/core/model_parallel_config.py#L95 + # use the build in finalize_model_grads function to all reduce gradients across parallelism dimensions + config.finalize_model_grads_func = finalize_model_grads + + def train(self): + [module.train() for module in self.actor_module] + + def eval(self): + [module.eval() for module in self.actor_module] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward_backward_mini_batch( + self, + micro_batches: List[dict], + seq_len: int, + micro_batch_size: int, + temperature: float = 1.0, + ) -> List[dict]: + """ + Run forward-backward over a full mini-batch consisting of multiple micro-batches. + + Args: + micro_batches: A list of micro-batch dicts. Each dict must contain keys: + "sequences", "attention_mask", "position_ids", "num_actions", + "old_action_log_probs", "base_action_log_probs", "advantages", + "loss_mask". + seq_len: Sequence length (tokens) per sample (assumed same across micros after padding). + micro_batch_size: Micro-batch size per forward pass. + temperature: Optional temperature for logits scaling. + + Returns: + List[dict]: one metrics dict per micro-batch in order. 
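+
+        Example:
+            A minimal sketch (variable names are illustrative, not part of this module)::
+
+                wrapper = MegatronModelWrapper(cfg, actor_module, actor_optimizer)
+                wrapper.train()
+                metrics_per_micro_batch = wrapper.forward_backward_mini_batch(
+                    micro_batches=[mb0, mb1, mb2, mb3],
+                    seq_len=1024,
+                    micro_batch_size=2,
+                    temperature=1.0,
+                )
+                assert len(metrics_per_micro_batch) == 4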
+ """ + forward_backward_func = get_forward_backward_func() + + def loss_func(logits, data): + sequences = data["sequences"] + num_actions = data["num_actions"] + old_action_log_probs = data["old_action_log_probs"] + advantages = data["advantages"] + loss_mask = data["loss_mask"] + + tp_grp = mpu.get_tensor_model_parallel_group() + tp_rank = mpu.get_tensor_model_parallel_rank() + + # temperature normalization + if temperature != 1.0: + logits.div_(temperature) + + token_logprobs = from_parallel_logits_to_logprobs( + logits, + sequences, + vocab_start_index=tp_rank * logits.shape[-1], + vocab_end_index=(tp_rank + 1) * logits.shape[-1], + tp_group=tp_grp, + inference_only=False, + cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` + chunk_size=None, + ) + + action_log_probs = token_logprobs[:, -num_actions:] + + # policy loss should be calculated based on the selected token logprobs + policy_loss, clip_ratio = ppo_policy_loss( + action_log_probs, + old_action_log_probs, + advantages, + config=self.cfg, + loss_mask=loss_mask, + ) + + # no kl loss or entropy loss + loss = policy_loss + + metrics = { + "policy_loss": policy_loss.detach().item(), + "ppo_clip_ratio": clip_ratio, + } + return loss, metrics + + def forward_step(batch_iter, model): + batch = next(batch_iter) + + sequences = batch["sequences"] + attention_mask = batch["attention_mask"].to(bool) + + new_sequences, packed_seq_params = preprocess_packed_seqs( + sequences, + attention_mask, + pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True), + ) + new_attention_mask = None + new_position_ids = None + + outputs = model( + new_sequences, + new_position_ids, + new_attention_mask, + packed_seq_params=packed_seq_params, + ) + + outputs = postprocess_packed_seqs( + outputs, + packed_seq_params, + attention_mask, + micro_batch_size, + seq_len, + post_process=mpu.is_pipeline_last_stage(ignore_virtual=True), + ) + + return outputs, partial(loss_func, data=batch) + + # batch should be a list of micro-batches + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.actor_module) + ) + + metrics_list = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=len(micro_batches), + seq_length=seq_len, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + # broadcast metrics to all pp ranks + if not mpu.is_pipeline_last_stage(ignore_virtual=True): + metrics_list = [None] * len(micro_batches) + with torch.no_grad(): + torch.distributed.broadcast_object_list( + metrics_list, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + + return metrics_list diff --git a/megatron_ray_fault_tolerant/megatron_utils.py b/megatron_ray_fault_tolerant/megatron_utils.py new file mode 100644 index 0000000..4a0f015 --- /dev/null +++ b/megatron_ray_fault_tolerant/megatron_utils.py @@ -0,0 +1,465 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron_utils.py#L4 +# https://github.com/volcengine/verl/blob/dfa3933ac44b545fca1f6a8519fd07394a2cde1c/verl/models/mcore/util.py +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import gc +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.transformer.module import Float16Module +from megatron.core.optimizer import ChainedOptimizer +from megatron.core import parallel_state as mpu +from megatron.core.utils import get_attr_wrapped_model +from megatron.core.packed_seq_params import PackedSeqParams + +ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def make_batch_generator(batches, vpp_size): + """ + Creates a batch generator suitable for Megatron pipeline parallelism, + handling virtual pipeline parallelism (VPP). + + If VPP is used (vpp_size > 1), it duplicates the batch iterator for each + virtual pipeline stage. Otherwise, it returns a single iterator. + + Args: + batches: An iterable (e.g., list) of micro-batches. + vpp_size (int): The virtual pipeline model parallel size. + + Returns: + An iterator or a list of iterators over the micro-batches. + """ + if vpp_size > 1: + # has vpp + batch_generator = [batches] * vpp_size # number of vpp chunks + batch_generator = [iter(b) for b in batch_generator] + else: + # no vpp + batch_generator = iter(batches) + return batch_generator + + +@torch.no_grad() +def offload_megatron_grads_to_cpu(models): + all_buffer_sizes = [] + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + buffer_sizes = [] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.grad_data.storage().size() > 0: + buffer_sizes.append(buffer.grad_data.storage().size()) + buffer.grad_data.storage().resize_(0) + all_buffer_sizes.append(buffer_sizes) + else: + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + return all_buffer_sizes + + +@torch.no_grad() +def load_megatron_grads_to_gpu(models, buffer_sizes): + for i, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for j, buffers in enumerate(model_chunk_all_buffers): + for buffer in buffers: + buffer.grad_data.storage().resize_(buffer_sizes[i][j]) + buffer.grad_data.zero_() + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + 
model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = ( + buffer.param_data.data.cpu().pin_memory() + ) + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert ( + buffer.param_data_size + == buffer.param_data.cpu_data.storage().size() + ) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_model_to_gpu(models): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_( + buffer.param_data.cpu_data, non_blocking=True + ) + else: + # we need this for ref module + device_id = torch.cuda.current_device() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to("cpu", non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. 
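+
+    Example:
+        A possible pairing (sketch; when to offload and reload is up to the caller)::
+
+            offload_megatron_copy_params(optimizer)  # park fp32 param copies on CPU
+            # ... memory-heavy phase such as rollout generation ...
+            load_megatron_copy_params(optimizer)  # bring them back before training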
+ """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = torch.cuda.current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, "_move_new_state_to_right_device"): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to( + torch.cuda.current_device(), non_blocking=True + ) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to( + torch.cuda.current_device(), non_blocking=True + ) + gc.collect() + torch.cuda.empty_cache() + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros( + batch_size + 1, dtype=torch.int32, device=input_ids.device + ) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = ( + seqlens_in_batch.tolist() + ) # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = ( + seqlens_in_batch_padded.tolist() + ) # lengths after padding + cu_seqlens_padded_cpu: list[int] = ( + cu_seqlens_padded.tolist() + ) # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros( + shape, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[ + i, attention_mask[i] + ] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[ + start_idx + half_seqlen : start_idx + half_seqlen + remain_len + ] = d[remain_start:remain_end] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move 
the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + seq_lens_cpu: list[int] = ( + attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() + ) + + shape = [batch_size, seq_len] + list( + output.shape[2:] + ) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather( + output_list, output.detach(), group=mpu.get_context_parallel_group() + ) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s] + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[ + packed_start_idx + + half_seqlen : packed_start_idx + + s_len_padded_chunk + ], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[ + s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen + ] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) diff --git a/megatron_ray_fault_tolerant/optimizer.py b/megatron_ray_fault_tolerant/optimizer.py new file mode 100644 index 0000000..f243397 --- /dev/null +++ b/megatron_ray_fault_tolerant/optimizer.py @@ -0,0 +1,103 @@ +# Utils ported from Verl +# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron/optimizer.py#L4 +# The original copyright is reproduced below: + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
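+
+# Typical wiring (sketch; names like `model_chunks` are placeholders):
+#   opt_cfg = init_megatron_optim_config({"lr": 1e-6, "weight_decay": 0.01})
+#   optimizer = get_megatron_optimizer(model_chunks, opt_cfg)
+#   scheduler = get_megatron_optimizer_param_scheduler(optimizer, {"lr": 1e-6})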
+
+import torch
+from megatron.core.optimizer import OptimizerConfig
+from megatron.core.optimizer import (
+    get_megatron_optimizer as get_megatron_optimizer_native,
+)
+from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
+
+
+def init_megatron_optim_config(optim_config) -> OptimizerConfig:
+    optim_args = {
+        "optimizer": optim_config.get("optimizer", "adam"),
+        "lr": optim_config.get("lr", 1.0e-6),
+        "min_lr": optim_config.get("min_lr", 0.0),
+        "clip_grad": optim_config.get("max_grad_norm", 1.0),
+        "weight_decay": optim_config.get("weight_decay", 0.01),
+        "bf16": True,
+        "params_dtype": torch.bfloat16,
+        "use_distributed_optimizer": True,
+    }
+
+    config = OptimizerConfig(**optim_args)
+    return config
+
+
+def get_megatron_optimizer(
+    model,
+    config: OptimizerConfig,
+    no_weight_decay_cond=None,
+    scale_lr_cond=None,
+    lr_mult=1.0,
+):
+    # Base optimizer.
+    return get_megatron_optimizer_native(
+        config=config,
+        model_chunks=model,
+        no_weight_decay_cond=no_weight_decay_cond,
+        scale_lr_cond=scale_lr_cond,
+        lr_mult=lr_mult,
+    )
+
+
+def get_megatron_optimizer_param_scheduler(
+    optimizer,
+    config,
+    num_training_steps: int = int(1e9),  # default to a large number for constant lr/wd
+):
+    """
+    Get the optimizer parameter scheduler for Megatron.
+    """
+    lr_warmup_steps = config.get("num_warmup_steps", 0)
+    lr_decay_steps = config.get("lr_decay_steps", None)
+    if lr_decay_steps is None:
+        lr_decay_steps = num_training_steps
+    if config.get("lr_warmup_steps_ratio", None) is not None and (
+        config.get("lr_warmup_steps", None) is None
+        or config.get("lr_warmup_steps", 0) <= 0
+    ):
+        lr_warmup_steps = int(config.get("lr_warmup_steps_ratio", 0.0) * lr_decay_steps)
+
+    opt_param_scheduler = OptimizerParamScheduler(
+        optimizer,
+        init_lr=config.get("lr_warmup_init", 0.0),
+        max_lr=config.get("lr", 1.0e-6),
+        min_lr=config.get("min_lr", 0.0),
+        lr_warmup_steps=lr_warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        lr_decay_style="constant",
+        start_wd=config.get("weight_decay", 0.01),
+        end_wd=config.get("weight_decay", 0.01),
+        wd_incr_steps=num_training_steps,
+        wd_incr_style="constant",
+        use_checkpoint_opt_param_scheduler=False,
+        override_opt_param_scheduler=True,
+        wsd_decay_steps=None,
+        lr_wsd_decay_style="exponential",
+    )
+
+    return opt_param_scheduler
+
+
+def get_megatron_last_lr(optimizer):
+    """
+    Get the last learning rate from the optimizer parameter scheduler.
+ """ + return optimizer.param_groups[0]["lr"] diff --git a/megatron_ray_fault_tolerant/pyproject.toml b/megatron_ray_fault_tolerant/pyproject.toml new file mode 100644 index 0000000..51be3c3 --- /dev/null +++ b/megatron_ray_fault_tolerant/pyproject.toml @@ -0,0 +1,98 @@ +[project] +name = "ray-ft" +version = "0.0.1" +description = "ray" +authors = [ + {name = "ray", email = "ray@gmail.com"} +] +license = {text = "MIT"} +readme = "README.md" +requires-python = "==3.12.*" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +dependencies = [ + "ninja", + "tensorboard", + "func_timeout", + "transformers>=4.51.0", + "torchdata", + "omegaconf", + "ray==2.51.0", + "peft", + "debugpy==1.8.0", + "hf_transfer", + "wandb", + "datasets==4.0.0", + "flash-attn", + "polars", + "loguru", + "jaxtyping", + "s3fs", + # Make sure to change the flash attention source (under tool.uv.sources) above to a compatible version (<= 2.7.4.post1) for TransformerEngine==2.5.0 + # https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl + # For single node: build transformer-engine separately first, and uncomment the transformer-engine library import below + # uv pip install "torch==2.7.1" + # uv pip install "nvidia-cudnn-cu12>=9.3" + # export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')" + # export CPATH="$CUDNN_PATH/include:${CPATH:-}" + # export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}" + # uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose + # "transformer-engine[pytorch]==2.5.0", + "transformer-engine[pytorch]==2.7.0", + "flash-attn==2.7.4.post1", + "vllm==0.10.1.1", + "torch==2.7.1", + "flashinfer-python", + "torchvision", + "megatron-bridge==0.1.0rc4", + "megatron-core==0.14.0", +] + +[tool.uv] +required-version = ">=0.8.10" +no-build-isolation-package = [ + "transformer-engine-torch", + "transformer-engine", +] + +[tool.uv.extra-build-dependencies] +flash-attn = [{requirement = "torch", match-runtime = true}] +transformer-engine = [{ requirement = "torch", match-runtime = true }, "build_tools"] +transformer-engine-torch = [{ requirement = "torch", match-runtime = true }, "build_tools"] + +[tool.uv.extra-build-variables] +flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"} + +[tool.uv.sources] +torch = { index = "pytorch-cu128" } +torchvision = { index = "pytorch-cu128" } +# We use `flashinfer-jit-cache` to avoid slow JIT compilation on first run. 
+# Different inference engines may pin different compatible flashinfer versions, so we provide the option to pin different versions for vllm/sglang +flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" } +flashinfer-python = [ + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" }, + { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" } +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[[tool.uv.index]] +name = "flashinfer-cu128" +url = "https://flashinfer.ai/whl/cu128" +explicit = true + +[tool.setuptools] +include-package-data = true + +[tool.pytest.ini_options] +addopts = "-v -s" +testpaths = [ + "tests", +] \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/run.sh b/megatron_ray_fault_tolerant/run.sh new file mode 100755 index 0000000..c9455a3 --- /dev/null +++ b/megatron_ray_fault_tolerant/run.sh @@ -0,0 +1 @@ +anyscale job submit -f job.yaml \ No newline at end of file diff --git a/megatron_ray_fault_tolerant/training_batch.py b/megatron_ray_fault_tolerant/training_batch.py new file mode 100644 index 0000000..eacdbe6 --- /dev/null +++ b/megatron_ray_fault_tolerant/training_batch.py @@ -0,0 +1,371 @@ +"""Defines interfaces for training data.""" + +from typing import TypedDict, Dict, Any, List, Optional, Generic, TypeVar +import torch +from jaxtyping import Float, Integer +import pickle +import io + +DictType = TypeVar("DictType") + + +# Class inspired by `TensorDict` but is much simpler. +class TensorBatch(dict, Generic[DictType]): + """Base class for training batches + + This defines a generic container for a batch of training data (inputs or outputs). + Consists of a dictionary of tensors along with some metadata. + """ + + metadata: Optional[Dict[str, Any]] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._batch_size = None + self._device = None + self._check_consistency() + + def select( + self, keys: List[str], metadata_keys: Optional[List[str]] = None + ) -> "TensorBatch[DictType]": + """Select a subset of the data batch. 
+ + Args: + keys: The keys to select + metadata_keys: The metadata keys to select + + Returns: + A new `TensorBatch` object with the selected keys and metadata + """ + selected_batch_data = {} + for key in keys: + selected_batch_data[key] = self[key] + selected_metadata = {} + if metadata_keys is None: + selected_metadata = self.metadata + else: + selected_metadata = {} + for key in metadata_keys: + selected_metadata[key] = self.metadata[key] + new_batch = self.__class__(selected_batch_data) + new_batch.metadata = selected_metadata + return new_batch + + def _check_consistency(self): + """Check consistency of all present fields""" + keys = list(self.keys()) + if len(keys) == 0: + return + + batch_size = len(self[keys[0]]) + self._batch_size = batch_size + for key in keys: + value = self[key] + if value is None: + continue + self._device = value.device if self._device is None else self._device + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + if len(value) != batch_size: + raise ValueError(f"Batch size mismatch in {key}") + if value.device != self._device: + raise ValueError( + f"Device mismatch in {key}. Expected {self._device}, got {value.device}" + ) + + def __getitem__(self, index) -> "TensorBatch[DictType]": + if isinstance(index, slice): + return self.slice(index.start, index.stop, index.step) + elif isinstance(index, int): + return self.slice(index, index + 1) + else: + return super().__getitem__(index) + + def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None: + if value is None: + super().__setitem__(key, value) + return + + if not isinstance(value, torch.Tensor): + raise ValueError(f"Field {key} must be a tensor, got {type(value)}") + + if ( + hasattr(self, "_batch_size") + and self._batch_size is not None + and len(value) != self._batch_size + ): + raise ValueError( + f"Batch size mismatch in {key}. Expected tensor to be of size {self._batch_size}, got {len(value)}." + ) + + super().__setitem__(key, value) + + if hasattr(self, "_batch_size") and self._batch_size is None: + self._batch_size = len(value) + + def to( + self, + device: torch.device = None, + dtype: torch.dtype = None, + *, + non_blocking: bool = False, + ) -> "TensorBatch": + """Move tensors to device and/or cast to dtype. 
+ + Args: + device: The device to move the tensors to + dtype: The dtype to cast the tensors to + non_blocking: Whether the operation should be non-blocking + """ + for key, value in self.items(): + if value is None: + continue + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.to(device, dtype, non_blocking=non_blocking) + return self + + def contiguous(self) -> "TensorBatch": + """Make the tensors contiguous""" + for key, value in self.items(): + if value is None: + continue + # some of these asserts are not needed, but it's kept for type safety + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + self[key] = value.contiguous() + return self + + @property + def batch_size(self) -> int: + """Batch size for the tensors""" + return self._batch_size + + @property + def device(self) -> torch.device: + """Get the device for the tensors""" + return self._device + + def __getstate__(self): + """Serialize the `TensorBatch` object for pickle protocol""" + self.contiguous() + if self._device is not None: + assert self._device == torch.device( + "cpu" + ), "Tensors must be on CPU before serialization" + batch_dict = {} + for key, value in self.items(): + buffer = io.BytesIO() + torch.save(value, buffer) + batch_dict[key] = buffer.getvalue() + + return { + "batch_dict": batch_dict, + "batch_size": self._batch_size, + "device": self._device, + "metadata": self.metadata, + } + + def __setstate__(self, state): + """Deserialize the `TensorBatch` object and load it into memory""" + for key, value in state["batch_dict"].items(): + buffer = io.BytesIO(value) + self[key] = torch.load(buffer) + + self._batch_size = state["batch_size"] + self._device = state["device"] + self.metadata = state["metadata"] + self._check_consistency() + return self + + def repeat(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated. + + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def repeat_interleave(self, repeats: int): + """Repeat entries in the data batch a specified number of times. + + This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated. 
+ + Args: + repeats: The number of times to repeat the data batch + + Returns: + A new `TensorBatch` object with the data repeated + """ + new_batch = {} + for key, value in self.items(): + if value is None: + new_batch[key] = value + else: + assert isinstance( + value, torch.Tensor + ), f"Field {key} must be a tensor, got {type(value)}" + new_batch[key] = value.repeat_interleave(repeats) + new_batch = self.__class__(new_batch) + new_batch.metadata = self.metadata + return new_batch + + def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]: + """Split into smaller chunks""" + chunks = [] + for i in range(0, self.batch_size, chunk_size): + chunk_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + chunk_data[key] = value[i : i + chunk_size] + else: + raise ValueError( + f"Unsupported type {type(value)} for key {key}" + ) + else: + # `None` values are not chunked + chunk_data[key] = value + chunk = self.__class__(chunk_data) + chunk.metadata = self.metadata + chunks.append(chunk) + return chunks + + def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]": + """Slice the data batch. + + Args: + start: The start index + end: The end index + step: The step size + + Returns: + A new `TensorBatch` object with the view of the specified slice. + """ + slice_obj = slice(start, end, step) + sliced_data = {} + for key, value in self.items(): + if value is not None: + if isinstance(value, torch.Tensor): + sliced_data[key] = value[slice_obj] + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not sliced + sliced_data[key] = value + sliced_batch = self.__class__(sliced_data) + sliced_batch.metadata = self.metadata + return sliced_batch + + def save(self, path: str): + """Save the data to a pickle file""" + with open(path, "wb") as f: + pickle.dump(self, f) + + def load(self, path: str): + """Load the data from a pickle file""" + with open(path, "rb") as f: + return pickle.load(f) + + @classmethod + def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]": + """Concatenate shards. + + Args: + shards: The list of `TensorBatch` objects to cat + + Returns: + A new `TensorBatch` object with the concatenated data + """ + cat_data = {} + assert len(shards) > 0, "Cannot cat an empty list of shards" + for key, value in shards[0].items(): + if value is not None: + if isinstance(value, torch.Tensor): + cat_data[key] = torch.cat([shard[key] for shard in shards]) + else: + raise ValueError(f"Unsupported type {type(value)} for key {key}") + else: + # `None` values are not cat'd + cat_data[key] = value + metadata = shards[0].metadata + cat_batch = cls(cat_data) + cat_batch.metadata = metadata + return cat_batch + + def __len__(self) -> int: + """Length of the batch. + + Note that this is the same as the batch size rather than the number of keys in the batch. 
+ """ + return self._batch_size + + def __eq__(self, other: Any) -> bool: + """Check if two `TensorBatch` objects are equal""" + if not isinstance(other, TensorBatch): + return False + if self.metadata != other.metadata: + return False + if len(self) != len(other): + return False + if len(self.items()) != len(other.items()): + return False + for k, v in self.items(): + if k not in other or not torch.equal(v, other[k]): + return False + return True + + def __str__(self) -> str: + """String representation of the `TensorBatch` object""" + return f"TensorBatch(batch_size={self.batch_size}, device={self.device}, metadata={self.metadata}), items={self.items()}" + + def __repr__(self) -> str: + """String representation of the `TensorBatch` object""" + return self.__str__() + + +class TrainingInput(TypedDict, total=False): + """Schema for training input batch""" + + sequences: Integer[torch.Tensor, "batch_size seq_len"] + attention_mask: Integer[torch.Tensor, "batch_size seq_len"] + loss_mask: Integer[torch.Tensor, "batch_size seq_len"] + response_mask: Integer[torch.Tensor, "batch_size seq_len"] + action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"] + values: Optional[Float[torch.Tensor, "batch_size seq_len"]] + returns: Float[torch.Tensor, "batch_size seq_len"] + advantages: Float[torch.Tensor, "batch_size seq_len"] + kl: Float[torch.Tensor, "batch_size seq_len"] + rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] + rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + + +class TrainingInputBatch(TensorBatch[TrainingInput]): + """Training input data""" + + pass + + +class TrainingOutputBatch(TensorBatch[Dict[str, torch.Tensor]]): + """Training output data""" + + pass diff --git a/megatron_ray_fault_tolerant/utils.py b/megatron_ray_fault_tolerant/utils.py new file mode 100644 index 0000000..e07689d --- /dev/null +++ b/megatron_ray_fault_tolerant/utils.py @@ -0,0 +1,286 @@ +import ray +from ray.util.placement_group import ( + PlacementGroup, + PlacementGroupSchedulingStrategy, + placement_group_table, +) +import torch +from typing import Any, Optional, Dict, List, Union, Tuple +from dataclasses import dataclass +from jaxtyping import Integer, Float +import math +from transformers import AutoTokenizer + + +from training_batch import TrainingInputBatch + +BasicType = Union[int, float, str, bool] + + +@ray.remote(num_gpus=1) +class InfoActor: + def get_gpu_id(self): + return ray.get_gpu_ids()[0] + + +def get_reordered_bundle_indices(pg: PlacementGroup): + """ + Get the reordered bundle indices for a placement group to ensure adjacent ranks are on the same node when possible + """ + pg_data = placement_group_table(pg) + num_bundles = len(pg_data["bundles"]) + bundle_to_node_ids = pg_data["bundles_to_node_id"] + + # use info actor to get the GPU id + info_actors = [] + for i in range(num_bundles): + info_actors.append( + InfoActor.options( + num_cpus=0.01, # set both num_cpus and num_gpus to be small values to enable assignment in colocated case + num_gpus=0.01, + resources=None, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=i, + ), + ).remote() + ) + + gpu_ids = ray.get([actor.get_gpu_id.remote() for actor in info_actors]) + for actor in info_actors: + ray.kill(actor) + + # original index, node_id, gpu_id + bundle_infos = [(i, bundle_to_node_ids[i], gpu_ids[i]) for i in range(num_bundles)] + pg_reordered_bundle_indices = [ + 
bundle_info[0]
+        for bundle_info in sorted(bundle_infos, key=lambda x: (x[1], x[2]))
+    ]  # sort by node_id, then gpu_id
+    return pg_reordered_bundle_indices
+
+
+def to(tensor: Union[torch.Tensor, List[torch.Tensor], BasicType], device):
+    if isinstance(tensor, list):
+        return [to(t, device) for t in tensor]
+    elif isinstance(tensor, torch.Tensor):
+        return tensor.to(device)
+    else:
+        return tensor
+
+
+@dataclass
+class Experience:
+    """Experience is a batch of data.
+    These data should have the same sequence length and number of actions.
+    Left padding for sequences is applied.
+
+    Shapes of each tensor:
+    sequences: (B, S)
+    action_log_probs: (B, A)
+    base_action_log_probs: (B, A)
+    values: (B, A)
+    returns: (B, A)
+    advantages: (B, A)
+    attention_mask: (B, S)
+    action_mask: (B, A)
+    kl: (B, A)
+
+    "A" is the number of actions (response length).
+    """
+
+    sequences: Integer[torch.Tensor, "batch seq_len"]
+    action_log_probs: Float[torch.Tensor, "batch response_len"]
+    base_action_log_probs: Optional[Float[torch.Tensor, "batch response_len"]]
+    values: Optional[Float[torch.Tensor, "batch response_len"]]
+    returns: Optional[Float[torch.Tensor, "batch response_len"]]
+    advantages: Optional[Float[torch.Tensor, "batch response_len"]]
+    attention_mask: Optional[Integer[torch.LongTensor, "batch seq_len"]]
+    loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]]
+    action_mask: Optional[Integer[torch.Tensor, "batch response_len"]]
+    rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]]
+    num_actions: int
+    info: Optional[dict]
+    kl: Optional[Float[torch.Tensor, "batch response_len"]] = None
+    metadata: Optional[Dict[str, Any]] = None
+
+    @torch.no_grad()
+    def to_device(self, device: torch.device) -> None:
+        self.sequences = to(self.sequences, device)
+        self.action_log_probs = to(self.action_log_probs, device)
+        if self.base_action_log_probs is not None:
+            self.base_action_log_probs = to(self.base_action_log_probs, device)
+        if self.values is not None:
+            self.values = to(self.values, device)
+        if self.returns is not None:
+            self.returns = to(self.returns, device)
+        if self.advantages is not None:
+            self.advantages = to(self.advantages, device)
+        if self.attention_mask is not None:
+            self.attention_mask = to(self.attention_mask, device)
+        if self.loss_mask is not None:
+            self.loss_mask = to(self.loss_mask, device)
+        if self.action_mask is not None:
+            self.action_mask = to(self.action_mask, device)
+        if self.rollout_logprobs is not None:
+            self.rollout_logprobs = to(self.rollout_logprobs, device)
+
+
+class BatchIterator:
+    """A simple iterator to yield micro batches of data from the training batch."""
+
+    def __init__(
+        self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False
+    ):
+        self.data = data
+        self.sample_batch_size = sample_batch_size
+        self.total_batch_size = data.batch_size
+        self.drop_last = drop_last
+        assert not drop_last, "drop_last is not supported yet"
+        num_micro_batches = self.total_batch_size / self.sample_batch_size
+        self.num_micro_batches = (
+            int(num_micro_batches) if drop_last else math.ceil(num_micro_batches)
+        )
+        # TODO: switch to tensordict.map_iter if possible
+        self._chunks = self.data.chunk(self.sample_batch_size)
+        self._iter = iter(self._chunks)
+
+    def __len__(self):
+        return self.num_micro_batches
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Experience:
+        try:
+            batch = next(self._iter)
+            exp = self.batch_to_experience(batch)
+            return exp
+        except StopIteration:
+            self._iter = 
iter(self._chunks) + raise StopIteration + + @staticmethod + def batch_to_experience(batch: TrainingInputBatch): + exp = Experience( + sequences=batch["sequences"], + action_log_probs=batch["action_log_probs"], + base_action_log_probs=batch["base_action_log_probs"], + values=batch["values"], + returns=batch["returns"], + advantages=batch["advantages"], + attention_mask=batch["attention_mask"], + loss_mask=batch["loss_mask"], + action_mask=batch["response_mask"], + num_actions=batch.metadata["response_length"], # int + rollout_logprobs=( + batch["rollout_logprobs"] if "rollout_logprobs" in batch else None + ), + # additional info + # can be used to log metrics etc for micro-batches in the worker + info={}, + # propagate metadata as is + metadata=batch.metadata, + ) + return exp + + +def masked_mean( + tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None +) -> torch.Tensor: + if mask is None: + return tensor.mean(axis=dim) + return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim).clamp(min=1.0) + + +def _safe_exp_delta( + delta: torch.Tensor, clip: float = 20.0, out_dtype=None +) -> torch.Tensor: + """ + Clamp the delta before exponentiating to avoid potential overflow. + """ + y = torch.clamp(delta.to(torch.float32), -clip, clip).exp() + return y.to(out_dtype or delta.dtype) + + +def ppo_policy_loss( + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + config, + loss_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, float]: + """Compute dual clip PPO policy loss.""" + ratio = _safe_exp_delta( + log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype + ) + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages + loss = -torch.min(surr1, surr2) + clip_ratio = ( + masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() + ) + clip_pg_losses1 = loss + pg_losses3 = -advantages * config.clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + + loss = loss = masked_mean(loss, loss_mask) + return loss, clip_ratio + + +def get_test_training_batch(model_name, batch_size=4) -> TrainingInputBatch: + """ + Returns a test training batch with padded seqs and attention masks + + Gives a batch of 4 sequences with variable amounts of left padding, and variable response lengths/amounts of right padding + Attention masks are 1 for non-padding tokens, 0 for padding tokens + The rest of the fields are filled with dummy data + """ + assert batch_size % 4 == 0, "batch size must be divisible by 4" + num_repeats = batch_size // 4 + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + sentences = [ + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.", + "<|im_start|>user\nThe selling price of a bicycle that had sold $220 last year was increased by 15", + "What is the new price? 
Let's think step by step and output the final answer after `####`.<|im_end|>\n", + "<|im_start|>assistant\nTo find the new price of the bicycle after the increase,", + ] * num_repeats + + sequences = [tokenizer.encode(sentence) for sentence in sentences] + attention_masks = [[1] * len(seq) for seq in sequences] + num_actions = 10 + # max seq len 1 longer than the longest sequence so we always have some padding + max_seq_length = max([len(seq) for seq in sequences]) + 7 + + pad_token_id = tokenizer.pad_token_id + pad_before = [4, 0, 1, 6] * num_repeats + pad_after = [ + max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences) + ] + + for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): + sequences[i] = ( + [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after + ) + attention_masks[i] = [0] * pad_before + attention_masks[i] + [0] * pad_after + + attention_masks = torch.tensor(attention_masks) + sequences = torch.tensor(sequences) + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_masks, + "action_log_probs": torch.tensor([[0.1] * num_actions] * batch_size), + "base_action_log_probs": torch.tensor([[0.2] * num_actions] * batch_size), + "rollout_logprobs": torch.tensor([[0.11] * num_actions] * batch_size), + "values": torch.tensor([[0.1] * num_actions] * batch_size), + "returns": torch.tensor([[0.1] * num_actions] * batch_size), + "advantages": torch.tensor([[0.5] * num_actions] * batch_size), + "loss_mask": torch.tensor([[1] * num_actions] * batch_size), + "response_mask": torch.tensor([[1] * num_actions] * batch_size), + } + ) + data.metadata = {"response_length": num_actions} + return data From 4936b7e7b84ff1dd1bd2265d40b046492cb9fa03 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Mon, 24 Nov 2025 16:10:33 -0800 Subject: [PATCH 10/15] Update megatron_ray_fault_tolerant: job config, main script, and documentation --- megatron_ray_fault_tolerant/CLAUDE.md | 87 +++++++++++++++++++++++++++ megatron_ray_fault_tolerant/job.yaml | 10 +-- megatron_ray_fault_tolerant/main.py | 23 ++++--- 3 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 megatron_ray_fault_tolerant/CLAUDE.md diff --git a/megatron_ray_fault_tolerant/CLAUDE.md b/megatron_ray_fault_tolerant/CLAUDE.md new file mode 100644 index 0000000..cce5be0 --- /dev/null +++ b/megatron_ray_fault_tolerant/CLAUDE.md @@ -0,0 +1,87 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +This is a PPO-style distributed training example using Megatron-Core and Ray with fault tolerance capabilities. The system can automatically recover from actor failures by utilizing backup actors and re-initializing NCCL process groups. + +## Commands + +### Running the example +```bash +uv run --isolated main.py +``` + +### Submit to Anyscale +```bash +anyscale job submit -f job.yaml +``` + +### Linting and formatting +```bash +ruff check --fix . +black . +``` + +## Architecture + +### Core Components + +**MegatronActor** (`megatron_actor.py:60-642`): Ray remote actor wrapping a Megatron model. 
Each actor: +- Owns one GPU and participates in distributed training +- Manages its own NCCL process group membership +- Handles model init, forward/backward, checkpointing, and recovery + +**MegatronActorGroup** (`megatron_actor.py:644-935`): Manages a collection of MegatronActors: +- Creates and places actors on placement group bundles +- Coordinates distributed operations via dispatch system +- Implements fault recovery by replacing dead actors with backups and re-initializing process groups + +**Dispatch System** (`dispatch.py`): Handles data distribution and result collection: +- `MeshDispatch`: Shards data across DP dimension, collects from primary ranks (SP=0, TP=0, PP=last) +- `PassThroughDispatch`: Broadcasts same data/commands to all actors +- Extensible via `register_dispatch_type()` + +**MegatronModelWrapper** (`megatron_model_wrapper.py`): Wraps Megatron model for PPO training: +- Handles micro-batch accumulation via Megatron's `forward_backward_func` +- Computes PPO policy loss with clipping + +### Data Flow + +1. `TrainingInputBatch` is dispatched to actors (sharded by DP rank for `MeshDispatch`) +2. Each actor runs `ppo_train()` which iterates micro-batches and accumulates gradients +3. Megatron handles gradient sync across TP/PP/DP dimensions +4. Results collected from primary DP ranks and concatenated + +### Parallelism Dimensions + +- **DP (Data Parallel)**: Distributes data across groups +- **TP (Tensor Parallel)**: Splits model tensors within a node +- **PP (Pipeline Parallel)**: Distributes model layers across stages +- **SP/CP (Context Parallel)**: Sequence parallelism for long contexts + +`MeshRank` dataclass tracks each actor's position in the 4D mesh. + +### Fault Recovery Flow + +1. Detect dead actors via `_check_actor_alive()` +2. Destroy old NCCL process groups on healthy actors +3. Pop backup actors from `backup_actor_group` +4. Update world size and reassign ranks +5. Re-initialize process groups and model components +6. Load checkpoint to restore state + +### Cloud Storage + +`file_io.py` provides unified file I/O for local/S3/GCS: +- `local_work_dir()`: Context manager for checkpoint saving (auto-uploads from temp dir) +- `local_read_dir()`: Context manager for checkpoint loading (auto-downloads to temp dir) + +## Key Configuration + +In `main.py`, the `Config` dataclass controls: +- `num_nodes`, `num_gpus_per_node`: Active worker topology +- `num_spare_gpus`: Backup actors for fault tolerance +- `megatron_config`: Parallelism settings (TP, PP, CP sizes) +- `ckpt_dir`: Checkpoint location (supports S3/GCS paths) diff --git a/megatron_ray_fault_tolerant/job.yaml b/megatron_ray_fault_tolerant/job.yaml index f1c2de2..1494dc7 100644 --- a/megatron_ray_fault_tolerant/job.yaml +++ b/megatron_ray_fault_tolerant/job.yaml @@ -12,17 +12,17 @@ containerfile: ./Dockerfile # When empty, Anyscale will auto-select the instance types. You can also specify # minimum and maximum resources. compute_config: - # Pin worker nodes to g6.xlarge (1xL4) so the vision workload lands on L4 GPUs. + # Pin worker nodes to g5.xlarge (A10G) so the fault tolerance workload lands on A10G GPUs. worker_nodes: - - instance_type: g6e.12xlarge + - instance_type: g5.12xlarge min_nodes: 0 - max_nodes: 2 + max_nodes: 9 min_resources: CPU: 0 GPU: 0 max_resources: - CPU: 384 - GPU: 64 + CPU: 432 + GPU: 36 # Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that # will be the working directory for the job. 
The files in the directory will be diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index 15eb7da..4a55b1d 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -4,6 +4,7 @@ from typing import Optional, List from megatron_actor import MegatronActorGroup from ray.util.placement_group import placement_group +from ray.runtime_env import RuntimeEnv, RuntimeEnvConfig import random import time @@ -50,9 +51,9 @@ class MegatronConfig: @dataclass class Config: - model: str = "Qwen/Qwen3-0.6B" + model: str = "Qwen/Qwen3-4B" # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it - num_nodes: int = 2 + num_nodes: int = 8 num_gpus_per_node: int = 4 mini_batch_size: int = 16 num_spare_gpus: int = 4 @@ -80,11 +81,13 @@ def main(): } ray.init(runtime_env=runtime_env) pg = placement_group( - [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node - + [{"GPU": 1, "CPU": 1}] * config.num_spare_gpus, + [{"GPU": 1, "CPU": 12}] * config.num_nodes * config.num_gpus_per_node + + [{"GPU": 1, "CPU": 12}] * config.num_spare_gpus, strategy="PACK", ) + ray.get(pg.ready(), timeout=1200) + print("Placement group ready") # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 reordered_bundle_indices = get_reordered_bundle_indices(pg) @@ -93,7 +96,7 @@ def main(): num_nodes=config.num_nodes, num_gpus_per_node=config.num_gpus_per_node, pg=pg, - bundle_indices=reordered_bundle_indices[:-config.num_spare_gpus], + bundle_indices=reordered_bundle_indices[: -config.num_spare_gpus], ) actor_group.initiate_worker_process_group() ray.get(actor_group.async_init_model(config.model)) @@ -104,7 +107,7 @@ def main(): num_nodes=config.num_spare_gpus // config.num_gpus_per_node, num_gpus_per_node=config.num_gpus_per_node, pg=pg, - bundle_indices=reordered_bundle_indices[-config.num_spare_gpus:], + bundle_indices=reordered_bundle_indices[-config.num_spare_gpus :], ) # just place but don't initiate the worker process group for the backup actor group # call a function to make sure the actors are placed @@ -129,14 +132,14 @@ def main(): # TODO: add a cpu offload (or cpu save memory) call here # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory # ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) - + # TODO: run another training batch here and save results but don't save checkpoint # randomly kill an actor to simulate fault tolerance scenario # TODO: go deeper into the actor code and throw an exception on a given node and catch it here print("Simulating failure and recovery...") start_time = time.time() - + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) # get the whole dp group associated with the failed actor dp_group_actors = [] @@ -191,7 +194,9 @@ def main(): "pass_through", "ppo_train", batch_after_recovery ) ) - print(f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds") + print( + f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds" + ) print("Recovery successful! 
Training works with remaining actors.") From 519eb3b84106fd175b4cbbc7f0c666694a8a6b98 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Mon, 24 Nov 2025 16:13:53 -0800 Subject: [PATCH 11/15] Remove CLAUDE.md --- megatron_ray_fault_tolerant/CLAUDE.md | 87 --------------------------- 1 file changed, 87 deletions(-) delete mode 100644 megatron_ray_fault_tolerant/CLAUDE.md diff --git a/megatron_ray_fault_tolerant/CLAUDE.md b/megatron_ray_fault_tolerant/CLAUDE.md deleted file mode 100644 index cce5be0..0000000 --- a/megatron_ray_fault_tolerant/CLAUDE.md +++ /dev/null @@ -1,87 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Overview - -This is a PPO-style distributed training example using Megatron-Core and Ray with fault tolerance capabilities. The system can automatically recover from actor failures by utilizing backup actors and re-initializing NCCL process groups. - -## Commands - -### Running the example -```bash -uv run --isolated main.py -``` - -### Submit to Anyscale -```bash -anyscale job submit -f job.yaml -``` - -### Linting and formatting -```bash -ruff check --fix . -black . -``` - -## Architecture - -### Core Components - -**MegatronActor** (`megatron_actor.py:60-642`): Ray remote actor wrapping a Megatron model. Each actor: -- Owns one GPU and participates in distributed training -- Manages its own NCCL process group membership -- Handles model init, forward/backward, checkpointing, and recovery - -**MegatronActorGroup** (`megatron_actor.py:644-935`): Manages a collection of MegatronActors: -- Creates and places actors on placement group bundles -- Coordinates distributed operations via dispatch system -- Implements fault recovery by replacing dead actors with backups and re-initializing process groups - -**Dispatch System** (`dispatch.py`): Handles data distribution and result collection: -- `MeshDispatch`: Shards data across DP dimension, collects from primary ranks (SP=0, TP=0, PP=last) -- `PassThroughDispatch`: Broadcasts same data/commands to all actors -- Extensible via `register_dispatch_type()` - -**MegatronModelWrapper** (`megatron_model_wrapper.py`): Wraps Megatron model for PPO training: -- Handles micro-batch accumulation via Megatron's `forward_backward_func` -- Computes PPO policy loss with clipping - -### Data Flow - -1. `TrainingInputBatch` is dispatched to actors (sharded by DP rank for `MeshDispatch`) -2. Each actor runs `ppo_train()` which iterates micro-batches and accumulates gradients -3. Megatron handles gradient sync across TP/PP/DP dimensions -4. Results collected from primary DP ranks and concatenated - -### Parallelism Dimensions - -- **DP (Data Parallel)**: Distributes data across groups -- **TP (Tensor Parallel)**: Splits model tensors within a node -- **PP (Pipeline Parallel)**: Distributes model layers across stages -- **SP/CP (Context Parallel)**: Sequence parallelism for long contexts - -`MeshRank` dataclass tracks each actor's position in the 4D mesh. - -### Fault Recovery Flow - -1. Detect dead actors via `_check_actor_alive()` -2. Destroy old NCCL process groups on healthy actors -3. Pop backup actors from `backup_actor_group` -4. Update world size and reassign ranks -5. Re-initialize process groups and model components -6. 
Load checkpoint to restore state - -### Cloud Storage - -`file_io.py` provides unified file I/O for local/S3/GCS: -- `local_work_dir()`: Context manager for checkpoint saving (auto-uploads from temp dir) -- `local_read_dir()`: Context manager for checkpoint loading (auto-downloads to temp dir) - -## Key Configuration - -In `main.py`, the `Config` dataclass controls: -- `num_nodes`, `num_gpus_per_node`: Active worker topology -- `num_spare_gpus`: Backup actors for fault tolerance -- `megatron_config`: Parallelism settings (TP, PP, CP sizes) -- `ckpt_dir`: Checkpoint location (supports S3/GCS paths) From 88eec3c9752ef9e8d0f22ec41bff65fddb861907 Mon Sep 17 00:00:00 2001 From: xyuzh Date: Mon, 24 Nov 2025 16:21:18 -0800 Subject: [PATCH 12/15] Update main.py configuration --- megatron_ray_fault_tolerant/main.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index 4a55b1d..6ab2059 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -72,14 +72,6 @@ def main(): config = Config() # create placement group including spare gpus - # need to set these env vars to avoid nccl error on nodes not supporting p2p - runtime_env = { - "env_vars": { - "NCCL_P2P_DISABLE": "1", - "NCCL_SHM_DISABLE": "1", - } - } - ray.init(runtime_env=runtime_env) pg = placement_group( [{"GPU": 1, "CPU": 12}] * config.num_nodes * config.num_gpus_per_node + [{"GPU": 1, "CPU": 12}] * config.num_spare_gpus, From eebe92085ef9e8dfcb867a8a28f82fc837984ef1 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 25 Nov 2025 01:08:33 +0000 Subject: [PATCH 13/15] working baseline and in the process of optimizing checkpointing --- megatron_ray_fault_tolerant/file_io.py | 12 +- megatron_ray_fault_tolerant/main.py | 186 +++++++++++++----- megatron_ray_fault_tolerant/megatron_actor.py | 41 ++-- megatron_ray_fault_tolerant/megatron_utils.py | 103 ++++++++-- megatron_ray_fault_tolerant/pyproject.toml | 2 +- megatron_ray_fault_tolerant/timer.py | 22 +++ 6 files changed, 286 insertions(+), 80 deletions(-) create mode 100644 megatron_ray_fault_tolerant/timer.py diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py index 932adbe..ca6324b 100644 --- a/megatron_ray_fault_tolerant/file_io.py +++ b/megatron_ray_fault_tolerant/file_io.py @@ -15,6 +15,7 @@ import fsspec from loguru import logger from datetime import datetime, timezone, timedelta +from typing import List # Optional AWS deps (present when s3fs is installed) try: @@ -218,7 +219,7 @@ def upload_directory(local_path: str, cloud_path: str) -> None: logger.info(f"Uploaded contents of {local_path} to {cloud_path}") -def download_directory(cloud_path: str, local_path: str) -> None: +def download_directory(cloud_path: str, local_path: str, prefixes: List[str] = None) -> None: """Download a cloud directory to local storage.""" if not is_cloud_path(cloud_path): raise ValueError(f"Source must be a cloud path, got: {cloud_path}") @@ -235,6 +236,8 @@ def download_directory(cloud_path: str, local_path: str) -> None: if remote_file.endswith("/"): continue rel_path = remote_file[len(remote_path_stripped) :].lstrip("/") + if prefixes and not any(rel_path.startswith(prefix) for prefix in prefixes): + continue local_file_path = os.path.join(local_path, rel_path) parent_dir = os.path.dirname(local_file_path) if parent_dir: @@ -246,6 +249,8 @@ def download_directory(cloud_path: str, local_path: str) -> None: if remote_file.endswith("/"): continue 
rel_path = remote_file[len(cloud_path_normalized) :].lstrip("/") + if prefixes and not any(rel_path.startswith(prefix) for prefix in prefixes): + continue local_file_path = os.path.join(local_path, rel_path) parent_dir = os.path.dirname(local_file_path) if parent_dir: @@ -290,7 +295,7 @@ def local_work_dir(output_path: str): @contextmanager -def local_read_dir(input_path: str): +def local_read_dir(input_path: str, prefixes: List[str] = None): """ Context manager that provides a local directory with content from input_path. @@ -299,6 +304,7 @@ def local_read_dir(input_path: str): Args: input_path: The source path (local or cloud) + prefixes: If using cloud path, only download files that start with any of these prefixes. If None, download all files. Yields: str: Local directory path containing the content @@ -311,7 +317,7 @@ def local_read_dir(input_path: str): if is_cloud_path(input_path): with tempfile.TemporaryDirectory() as temp_dir: # Download everything from cloud path to temp_dir - download_directory(input_path, temp_dir) + download_directory(input_path, temp_dir, prefixes) logger.info(f"Downloaded directory contents from {input_path}") yield temp_dir else: diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index 15eb7da..d5e5bf5 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -1,4 +1,5 @@ import os +import argparse from dataclasses import dataclass, field import ray from typing import Optional, List @@ -8,6 +9,7 @@ import random import time from utils import get_test_training_batch, get_reordered_bundle_indices +from timer import Timer @dataclass @@ -51,7 +53,6 @@ class MegatronConfig: @dataclass class Config: model: str = "Qwen/Qwen3-0.6B" - # TODO: test on actually more than 2 nodes for recovery, where we just want to ditch a whole node and replace it num_nodes: int = 2 num_gpus_per_node: int = 4 mini_batch_size: int = 16 @@ -112,88 +113,171 @@ def main(): # train on one batch batch = get_test_training_batch(config.model, batch_size=32) - print("Starting training step 1...") - start_time = time.time() ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) - print(f"Training step 1 took {time.time() - start_time:.2f} seconds") # save checkpoint - start_time = time.time() ray.get( actor_group.async_run_ray_method( "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir ) ) - print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") # TODO: add a cpu offload (or cpu save memory) call here # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory - # ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + print("Saving a copy of the model and optimizer state to cpu memory...") + ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) + ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu")) # TODO: run another training batch here and save results but don't save checkpoint + # train on one batch + batch = get_test_training_batch(config.model, batch_size=32) + step_2_metrics = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) + # randomly kill an actor to simulate fault tolerance scenario # TODO: go deeper into the actor code and throw an exception on a given node and catch it here - print("Simulating failure and recovery...") - start_time = time.time() + # try to potentially integrate nvidia-fault-tolerance extension here? 
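+    # The block below exercises the whole recovery path: pick a random actor, kill
+    # its entire data-parallel group, destroy the old NCCL process groups on the
+    # surviving actors, promote hot-spare actors from backup_actor_group into the
+    # vacated ranks, and reload the checkpoint so training resumes on the rebuilt
+    # process group.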
+ with Timer("Simulating failure and recovery with hot spares"): - actor_id = random.randint(0, len(actor_group.actor_infos) - 1) - # get the whole dp group associated with the failed actor - dp_group_actors = [] - for actor_info in actor_group.actor_infos: - if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: - dp_group_actors.append(actor_info) - print( - f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." - ) - for actor_info in dp_group_actors: - ray.kill(actor_info.handle) + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) + # get the whole dp group associated with the failed actor + dp_group_actors = [] + for actor_info in actor_group.actor_infos: + if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: + dp_group_actors.append(actor_info) + print( + f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." + ) + for actor_info in dp_group_actors: + ray.kill(actor_info.handle) - # Destroy process groups on all actors (including dead ones, which will fail gracefully) - print("Destroying old process groups...") - try: + # Destroy process groups on all actors (including dead ones, which will fail gracefully) + print("Destroying old process groups...") + try: + ray.get( + actor_group.async_run_ray_method( + "pass_through", "destroy_worker_process_group" + ) + ) + except Exception as e: + print(f"Some actors failed during destroy (expected): {e}") + + alive_actor_ids = [] + dead_actor_ids = [] + for i, actor_info in enumerate(actor_group.actor_infos): + is_alive = actor_group._check_actor_alive(actor_info.handle) + print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") + if is_alive: + alive_actor_ids.append(i) + else: + dead_actor_ids.append(i) + # Recover from failure: remove dead actors and re-initialize process group + print("Recovering from actor failure...") + actor_group.recover_from_failure(backup_actor_group) + + # load checkpoint on all actors + # TODO: improve the logic here + # we want to only call load checkpoint on the actors that are fresh + # on previously healthy actors we want to restore weights and optimizer state from cpu memory + # ray.get(actor_group.async_run_method_no_dispatch("backload_to_gpu", actor_ids=alive_actor_ids)) + # only for new actors, we want to load the checkpoint ray.get( - actor_group.async_run_ray_method( - "pass_through", "destroy_worker_process_group" + actor_group.async_run_method_no_dispatch( + "load_checkpoint", ckpt_dir=config.ckpt_dir ) ) - except Exception as e: - print(f"Some actors failed during destroy (expected): {e}") - - for i, actor_info in enumerate(actor_group.actor_infos): - is_alive = actor_group._check_actor_alive(actor_info.handle) - print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") - - # Recover from failure: remove dead actors and re-initialize process group - print("Recovering from actor failure...") - actor_group.recover_from_failure(backup_actor_group) - - # load checkpoint on all actors - # TODO: improve the logic here - # we want to only call load checkpoint on the actors that are fresh - # on previously healthy actors we want to restore weights and optimizer state from cpu memory - # ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu"), actor_ids=[previously healthy actor ids]) - # only for new actors, we want to load the checkpoint - ray.get( - actor_group.async_run_ray_method( - "pass_through", "load_checkpoint", ckpt_dir=config.ckpt_dir 
- ) - ) - print(f"Recovery took {time.time() - start_time:.2f} seconds") # TODO: check that results here are the same as before the failure when resuming from checkpoint # Test that training still works after recovery print("Testing training after recovery...") batch_after_recovery = get_test_training_batch(config.model, batch_size=32) - start_time = time.time() ray.get( actor_group.async_run_ray_method( "pass_through", "ppo_train", batch_after_recovery ) ) - print(f"Training step 2 (after recovery) took {time.time() - start_time:.2f} seconds") print("Recovery successful! Training works with remaining actors.") +def baseline(): + config = Config() + # create placement group including spare gpus + + # need to set these env vars to avoid nccl error on nodes not supporting p2p + runtime_env = { + "env_vars": { + "NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": "1", + } + } + ray.init(runtime_env=runtime_env) + pg = placement_group( + [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node, + strategy="PACK", + ) + ray.get(pg.ready(), timeout=1200) + # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 + reordered_bundle_indices = get_reordered_bundle_indices(pg) + + actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_nodes, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices, + ) + actor_group.initiate_worker_process_group() + ray.get(actor_group.async_init_model(config.model)) + + batch = get_test_training_batch(config.model, batch_size=32) + ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) + + # save checkpoint + start_time = time.time() + ray.get( + actor_group.async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") + + # simulate full teardown and restart + with Timer("Full teardown and restart"): + ray.shutdown() + ray.init(runtime_env=runtime_env) + pg = placement_group( + [{"GPU": 1, "CPU": 1}] * config.num_nodes * config.num_gpus_per_node, + strategy="PACK", + ) + ray.get(pg.ready(), timeout=1200) + # this is needed because placement group gpu bundle order is not deterministic: https://github.com/ray-project/ray/issues/51117 + reordered_bundle_indices = get_reordered_bundle_indices(pg) + + actor_group = MegatronActorGroup( + cfg=config, + num_nodes=config.num_nodes, + num_gpus_per_node=config.num_gpus_per_node, + pg=pg, + bundle_indices=reordered_bundle_indices, + ) + actor_group.initiate_worker_process_group() + ray.get(actor_group.async_init_model(config.model)) + + ray.get( + actor_group.async_run_method_no_dispatch( + "load_checkpoint", ckpt_dir=config.ckpt_dir + ) + ) + + batch = get_test_training_batch(config.model, batch_size=32) + ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--mode", type=str, default="main") + args = parser.parse_args() + if args.mode == "main": + main() + elif args.mode == "baseline": + baseline() + else: + raise ValueError(f"Invalid mode: {args.mode}") diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py index e6d330c..304c7f6 100644 --- a/megatron_ray_fault_tolerant/megatron_actor.py +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -46,9 +46,10 @@ from megatron_model_wrapper import MegatronModelWrapper 
from megatron_utils import ( offload_megatron_model_to_cpu, - offload_megatron_optimizer, + snapshot_optimizer_state_cpu, load_megatron_model_to_gpu, load_megatron_optimizer, + apply_optimizer_state_snapshot, offload_megatron_grads_to_cpu, load_megatron_grads_to_gpu, ) @@ -330,6 +331,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": ) micro_buffer = [] + all_metrics = [] for local_step, experience in enumerate(pbar): experience.to_device(torch.cuda.current_device()) sequences = experience.sequences @@ -360,18 +362,20 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": seq_len = micro_buffer[0]["sequences"].shape[1] micro_bsz = micro_buffer[0]["sequences"].shape[0] - self.model.forward_backward_mini_batch( + metrics = self.model.forward_backward_mini_batch( micro_batches=micro_buffer, seq_len=seq_len, micro_batch_size=micro_bsz, ) - + all_metrics.extend(metrics) _, grad_norm, _ = self.optimizer.step() self.scheduler.step(1) self.optimizer.zero_grad() torch.distributed.barrier() + return all_metrics + def save_checkpoint(self, ckpt_dir: str): # Extract base model. model: List[nn.Module] = self.model.actor_module @@ -409,7 +413,7 @@ def save_checkpoint(self, ckpt_dir: str): # Save the checkpoint across ranks in parallel. save_strategy = get_default_save_sharded_strategy("torch_dist") save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + save_strategy, mpu.get_model_parallel_group() ) with io.local_work_dir(ckpt_dir) as work_dir: @@ -464,11 +468,12 @@ def load_checkpoint( # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory # this should be improved to download only the relevant shards for this actor to load + # prefixes=[f"__{self._rank}_", ".metadata", "common.pt", "metadata.json"] with io.local_read_dir(ckpt_dir) as read_dir: # Load the checkpoint in parallel. 
load_strategy = get_default_load_sharded_strategy(read_dir) load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + load_strategy, mpu.get_model_parallel_group() ) state_dict = dist_checkpointing.load( sharded_state_dict=sharded_state_dict, @@ -505,15 +510,16 @@ def load_checkpoint( def offload_to_cpu(self): self.all_buffer_sizes = offload_megatron_grads_to_cpu(self.actor_module) - self.all_model_weights_and_sizes = offload_megatron_model_to_cpu(self.actor_module) - self.all_optimizer_weights_and_sizes = offload_megatron_optimizer(self.optimizer) + self.all_model_buffers_param_data, self.all_model_buffers_param_data_sizes = offload_megatron_model_to_cpu(self.actor_module) + self.all_optimizer_state_dict = snapshot_optimizer_state_cpu(self.optimizer) torch.cuda.synchronize() torch.cuda.empty_cache() def backload_to_gpu(self): load_megatron_grads_to_gpu(self.actor_module, self.all_buffer_sizes) - load_megatron_model_to_gpu(self.actor_module, self.all_model_weights_and_sizes) - load_megatron_optimizer(self.optimizer, self.all_optimizer_weights_and_sizes) + load_megatron_model_to_gpu(self.actor_module, self.all_model_buffers_param_data, self.all_model_buffers_param_data_sizes) + apply_optimizer_state_snapshot(self.optimizer, self.all_optimizer_state_dict) + load_megatron_optimizer(self.optimizer) torch.cuda.synchronize() torch.cuda.empty_cache() @@ -800,13 +806,20 @@ def async_run_ray_method( return object_refs def async_run_method_no_dispatch( - self, method_name: str, *args, **kwargs + self, method_name: str, actor_ids: List[int] = None, *args, **kwargs ) -> List[ObjectRef]: """Run a method on all actors without dispatching.""" - return [ - getattr(handle, method_name).remote(*args, **kwargs) - for handle in self._actor_handlers - ] + if actor_ids is None: + return [ + getattr(handle, method_name).remote(*args, **kwargs) + for handle in self._actor_handlers + ] + else: + object_refs = [] + for i, handle in enumerate(self._actor_handlers): + if i in actor_ids: + object_refs.append(getattr(handle, method_name).remote(*args, **kwargs)) + return object_refs def _check_actor_alive(self, actor_handle) -> bool: """Check if an actor is still alive by attempting to call a simple method.""" diff --git a/megatron_ray_fault_tolerant/megatron_utils.py b/megatron_ray_fault_tolerant/megatron_utils.py index e1e3175..4d70171 100644 --- a/megatron_ray_fault_tolerant/megatron_utils.py +++ b/megatron_ray_fault_tolerant/megatron_utils.py @@ -22,6 +22,7 @@ import torch import gc +from typing import Any, Dict, List, Union from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.transformer.module import Float16Module from megatron.core.optimizer import ChainedOptimizer @@ -90,10 +91,12 @@ def load_megatron_grads_to_gpu(models, buffer_sizes): model_chunk.buffers, model_chunk.expert_parallel_buffers, ] - for j, buffers in enumerate(model_chunk_all_buffers): + j = 0 + for buffers in model_chunk_all_buffers: for buffer in buffers: buffer.grad_data.storage().resize_(buffer_sizes[i][j]) buffer.grad_data.zero_() + j += 1 else: # we need this for ref module for _, param in model_chunk.named_parameters(): @@ -114,51 +117,59 @@ def offload_megatron_model_to_cpu(models): - fp32 main_parameter chunked in model and dp group - fp32 optimizer state chunked in model and dp group """ - # all_model_weights + all_model_buffers_param_data = [] + all_model_buffers_param_data_sizes = [] for model_chunk in models: if 
isinstance(model_chunk, DDP): model_chunk_all_buffers = [ model_chunk.buffers, model_chunk.expert_parallel_buffers, ] + model_chunk_buffers_param_data = [] + model_chunk_buffers_param_data_sizes = [] for buffers in model_chunk_all_buffers: for buffer in buffers: # offload parameters if buffer.param_data.storage().size() > 0: - buffer.param_data.cpu_data = ( + model_chunk_buffers_param_data.append( buffer.param_data.data.cpu().pin_memory() ) - buffer.param_data_size = buffer.param_data.storage().size() + model_chunk_buffers_param_data_sizes.append(buffer.param_data.storage().size()) buffer.param_data.storage().resize_(0) assert ( - buffer.param_data_size - == buffer.param_data.cpu_data.storage().size() + model_chunk_buffers_param_data_sizes[-1] + == model_chunk_buffers_param_data[-1].data.storage().size() ) + all_model_buffers_param_data.append(model_chunk_buffers_param_data) + all_model_buffers_param_data_sizes.append(model_chunk_buffers_param_data_sizes) else: # we need this for ref module for _, param in model_chunk.named_parameters(): param.data = param.data.to("cpu", non_blocking=True) gc.collect() torch.cuda.empty_cache() + return all_model_buffers_param_data, all_model_buffers_param_data_sizes @torch.no_grad() -def load_megatron_model_to_gpu(models): - for model_chunk in models: +def load_megatron_model_to_gpu(models, all_model_buffers_param_data, all_model_buffers_param_data_sizes): + for i, model_chunk in enumerate(models): if isinstance(model_chunk, DDP): model_chunk_all_buffers = [ model_chunk.buffers, model_chunk.expert_parallel_buffers, ] + j = 0 for buffers in model_chunk_all_buffers: for buffer in buffers: if buffer.param_data.storage().size() == 0: - buffer.param_data.storage().resize_(buffer.param_data_size) + buffer.param_data.storage().resize_(all_model_buffers_param_data_sizes[i][j]) # copy data from cpu to cuda buffer.param_data.copy_( - buffer.param_data.cpu_data, non_blocking=True + all_model_buffers_param_data[i][j], non_blocking=True ) + j += 1 else: # we need this for ref module device_id = torch.cuda.current_device() @@ -167,7 +178,6 @@ def load_megatron_model_to_gpu(models): gc.collect() torch.cuda.empty_cache() - @torch.no_grad() def offload_megatron_copy_params(optimizers): """ @@ -269,6 +279,77 @@ def _iter_opts(opt): torch.cuda.empty_cache() +@torch.no_grad() +def snapshot_optimizer_state_cpu( + optimizers: Union[torch.optim.Optimizer, ChainedOptimizer] +) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """ + Offload optimizer tensors to CPU and return a Python object snapshot of the state. + The returned object can later be used to restore/override a freshly created optimizer. + Supports both single Megatron optimizers and `ChainedOptimizer`. 
+ + Returns: + - dict when a single optimizer is provided + - list[dict] when a ChainedOptimizer is provided (one dict per underlying optimizer) + """ + # Ensure all relevant optimizer tensors are on CPU first (params and state) + offload_megatron_optimizer(optimizers) + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + snapshots: List[Dict[str, Any]] = [] + for _opt in _iter_opts(optimizers): + base_opt = getattr(_opt, "optimizer", _opt) + # state_dict() returns a pure-Python object referencing tensors (now on CPU) + snapshots.append(base_opt.state_dict()) + + if isinstance(optimizers, ChainedOptimizer): + return snapshots + return snapshots[0] + + +@torch.no_grad() +def apply_optimizer_state_snapshot( + optimizers: Union[torch.optim.Optimizer, ChainedOptimizer], + snapshot: Union[Dict[str, Any], List[Dict[str, Any]]], +) -> None: + """ + Apply a previously captured CPU snapshot to the provided optimizer(s), + overriding their state. Supports both single optimizers and `ChainedOptimizer`. + + If the underlying optimizer supports device specialization (e.g., HybridDeviceOptimizer), + this will also move newly loaded states to the correct device. + """ + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + if isinstance(optimizers, ChainedOptimizer): + if not isinstance(snapshot, (list, tuple)): + raise ValueError("Expected a list of state_dicts for ChainedOptimizer snapshot.") + if len(snapshot) != len(optimizers.chained_optimizers): + raise ValueError( + f"Snapshot length ({len(snapshot)}) does not match number of chained optimizers " + f"({len(optimizers.chained_optimizers)})." + ) + items = zip(optimizers.chained_optimizers, snapshot) + else: + if not isinstance(snapshot, dict): + raise ValueError("Expected a dict state_dict snapshot for a single optimizer.") + items = [ (optimizers, snapshot) ] + + for _opt, _state in items: + base_opt = getattr(_opt, "optimizer", _opt) + base_opt.load_state_dict(_state) + # Align newly loaded states to the intended device if available (Megatron HybridDeviceOptimizer) + if hasattr(base_opt, "_move_new_state_to_right_device"): + base_opt._move_new_state_to_right_device() + + @torch.no_grad() def load_megatron_optimizer(optimizers): def _iter_opts(opt): diff --git a/megatron_ray_fault_tolerant/pyproject.toml b/megatron_ray_fault_tolerant/pyproject.toml index fb43490..38cff67 100644 --- a/megatron_ray_fault_tolerant/pyproject.toml +++ b/megatron_ray_fault_tolerant/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "transformers>=4.51.0", "torchdata", "omegaconf", - "ray==2.51.0", + "ray==2.51.1", "peft", "debugpy==1.8.0", "hf_transfer", diff --git a/megatron_ray_fault_tolerant/timer.py b/megatron_ray_fault_tolerant/timer.py new file mode 100644 index 0000000..df720ff --- /dev/null +++ b/megatron_ray_fault_tolerant/timer.py @@ -0,0 +1,22 @@ +import time +from loguru import logger + +class Timer: + def __init__(self, message): + self.message = message + + def __enter__(self): + self.start_time = time.time() + logger.opt(depth=1).info(f"Started: '{self.message}'") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logger.opt(depth=1).info(f"Finished: '{self.message}', time cost: {time.time() - self.start_time:.2f}s") + + async def __aenter__(self): + self.start_time = time.time() + logger.opt(depth=1).info(f"Started: '{self.message}'") + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + 
logger.opt(depth=1).info(f"Finished: '{self.message}', time cost: {time.time() - self.start_time:.2f}s") \ No newline at end of file From 760b93e53451fbdbebc4b6f5f18cc088975ae819 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 26 Nov 2025 19:39:18 +0000 Subject: [PATCH 14/15] fast(ish) model checkpointing/loading with s3 --- megatron_ray_fault_tolerant/file_io.py | 19 ++- megatron_ray_fault_tolerant/main.py | 112 ++++++++++-------- megatron_ray_fault_tolerant/megatron_actor.py | 14 +-- 3 files changed, 81 insertions(+), 64 deletions(-) diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py index ca6324b..67f6a97 100644 --- a/megatron_ray_fault_tolerant/file_io.py +++ b/megatron_ray_fault_tolerant/file_io.py @@ -15,7 +15,7 @@ import fsspec from loguru import logger from datetime import datetime, timezone, timedelta -from typing import List +from typing import List, Optional # Optional AWS deps (present when s3fs is installed) try: @@ -295,7 +295,7 @@ def local_work_dir(output_path: str): @contextmanager -def local_read_dir(input_path: str, prefixes: List[str] = None): +def local_read_dir(input_path: str, local_path: Optional[str] = None, prefixes: List[str] = None): """ Context manager that provides a local directory with content from input_path. @@ -304,6 +304,7 @@ def local_read_dir(input_path: str, prefixes: List[str] = None): Args: input_path: The source path (local or cloud) + local_path: The local path to download the directory to. If None, use a temporary directory. prefixes: If using cloud path, only download files that start with any of these prefixes. If None, download all files. Yields: @@ -315,11 +316,17 @@ def local_read_dir(input_path: str, prefixes: List[str] = None): model = AutoModel.from_pretrained(read_dir) """ if is_cloud_path(input_path): - with tempfile.TemporaryDirectory() as temp_dir: - # Download everything from cloud path to temp_dir - download_directory(input_path, temp_dir, prefixes) + if local_path is None: + with tempfile.TemporaryDirectory() as temp_dir: + # Download everything from cloud path to temp_dir + download_directory(input_path, temp_dir, prefixes) + logger.info(f"Downloaded directory contents from {input_path}") + yield temp_dir + else: + # Download everything from cloud path to local_path + download_directory(input_path, local_path, prefixes) logger.info(f"Downloaded directory contents from {input_path}") - yield temp_dir + yield local_path else: # For local paths, use directly (but check it exists) if not exists(input_path): diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index 1876c79..b8d8d5d 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -61,8 +61,9 @@ class Config: micro_train_batch_size_per_gpu: int = 2 megatron_config: MegatronConfig = field(default_factory=MegatronConfig) ckpt_dir: str = ( - os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt3/" + os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt4/" ) + local_ckpt_dir: str = "/tmp/megatron_fault_tolerance/ckpt4/" # used for storing checkpoint on local disk when downloading from cloud # algorithm config eps_clip_low: float = 0.2 eps_clip_high: float = 0.2 @@ -72,6 +73,14 @@ class Config: def main(): config = Config() # create placement group including spare gpus + # need to set these env vars to avoid nccl error on nodes not supporting p2p + runtime_env = { + "env_vars": { + "NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": 
"1", + } + } + ray.init(runtime_env=runtime_env) pg = placement_group( [{"GPU": 1, "CPU": 12}] * config.num_nodes * config.num_gpus_per_node @@ -119,6 +128,7 @@ def main(): # TODO: add a cpu offload (or cpu save memory) call here # in order for the healthy actors to save a copy of the model and optimizer state to cpu memory + # figure out if it's possible to do keep this on GPU instead of CPU print("Saving a copy of the model and optimizer state to cpu memory...") ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu")) ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu")) @@ -133,42 +143,45 @@ def main(): # TODO: go deeper into the actor code and throw an exception on a given node and catch it here # try to potentially integrate nvidia-fault-tolerance extension here? with Timer("Simulating failure and recovery with hot spares"): - - actor_id = random.randint(0, len(actor_group.actor_infos) - 1) - # get the whole dp group associated with the failed actor - dp_group_actors = [] - for actor_info in actor_group.actor_infos: - if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: - dp_group_actors.append(actor_info) - print( - f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." - ) - for actor_info in dp_group_actors: - ray.kill(actor_info.handle) + + with Timer("Killing actors to simulate failure"): + actor_id = random.randint(0, len(actor_group.actor_infos) - 1) + # get the whole dp group associated with the failed actor + dp_group_actors = [] + for actor_info in actor_group.actor_infos: + if actor_info.rank.dp == actor_group.actor_infos[actor_id].rank.dp: + dp_group_actors.append(actor_info) + print( + f"Killing actors {[actor_info.rank for actor_info in dp_group_actors]} to simulate failure..." 
+ ) + for actor_info in dp_group_actors: + ray.kill(actor_info.handle) # Destroy process groups on all actors (including dead ones, which will fail gracefully) - print("Destroying old process groups...") - try: - ray.get( - actor_group.async_run_ray_method( - "pass_through", "destroy_worker_process_group" + with Timer("Destroying old process groups..."): + try: + ray.get( + actor_group.async_run_ray_method( + "pass_through", "destroy_worker_process_group" + ) ) - ) - except Exception as e: - print(f"Some actors failed during destroy (expected): {e}") - - alive_actor_ids = [] - dead_actor_ids = [] - for i, actor_info in enumerate(actor_group.actor_infos): - is_alive = actor_group._check_actor_alive(actor_info.handle) - print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") - if is_alive: - alive_actor_ids.append(i) - else: - dead_actor_ids.append(i) - # Recover from failure: remove dead actors and re-initialize process group - print("Recovering from actor failure...") - actor_group.recover_from_failure(backup_actor_group) + except Exception as e: + print(f"Some actors failed during destroy (expected): {e}") + + with Timer("Checking actor status..."): + alive_actor_ids = [] + dead_actor_ids = [] + for i, actor_info in enumerate(actor_group.actor_infos): + is_alive = actor_group._check_actor_alive(actor_info.handle) + print(f"Actor {i} (handle: {actor_info.handle}) is alive: {is_alive}") + if is_alive: + alive_actor_ids.append(i) + else: + dead_actor_ids.append(i) + + with Timer("Recovering from actor failure..."): + # Recover from failure: remove dead actors and re-initialize process group + actor_group.recover_from_failure(backup_actor_group) # load checkpoint on all actors # TODO: improve the logic here @@ -176,11 +189,12 @@ def main(): # on previously healthy actors we want to restore weights and optimizer state from cpu memory # ray.get(actor_group.async_run_method_no_dispatch("backload_to_gpu", actor_ids=alive_actor_ids)) # only for new actors, we want to load the checkpoint - ray.get( - actor_group.async_run_method_no_dispatch( - "load_checkpoint", ckpt_dir=config.ckpt_dir + with Timer("Loading checkpoint..."): + ray.get( + actor_group.async_run_method_no_dispatch( + "load_checkpoint", ckpt_dir=config.ckpt_dir + ) ) - ) # TODO: check that results here are the same as before the failure when resuming from checkpoint # Test that training still works after recovery @@ -227,14 +241,12 @@ def baseline(): ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) # save checkpoint - start_time = time.time() - ray.get( - actor_group.async_run_ray_method( - "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + with Timer("Saving checkpoint..."): + ray.get( + actor_group.async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=config.ckpt_dir + ) ) - ) - print(f"Checkpoint saving took {time.time() - start_time:.2f} seconds") - # simulate full teardown and restart with Timer("Full teardown and restart"): ray.shutdown() @@ -256,12 +268,12 @@ def baseline(): ) actor_group.initiate_worker_process_group() ray.get(actor_group.async_init_model(config.model)) - - ray.get( - actor_group.async_run_method_no_dispatch( - "load_checkpoint", ckpt_dir=config.ckpt_dir + with Timer("Loading checkpoint..."): + ray.get( + actor_group.async_run_method_no_dispatch( + "load_checkpoint", ckpt_dir=config.ckpt_dir + ) ) - ) batch = get_test_training_batch(config.model, batch_size=32) ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) diff --git 
a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py index 304c7f6..95653b2 100644 --- a/megatron_ray_fault_tolerant/megatron_actor.py +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -412,9 +412,7 @@ def save_checkpoint(self, ckpt_dir: str): # Save the checkpoint across ranks in parallel. save_strategy = get_default_save_sharded_strategy("torch_dist") - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, mpu.get_model_parallel_group() - ) + save_strategy = FullyParallelSaveStrategyWrapper(save_strategy) with io.local_work_dir(ckpt_dir) as work_dir: # synchronous checkpointing for now @@ -466,19 +464,19 @@ def load_checkpoint( if scheduler and load_lr_scheduler_states: sharded_state_dict["lr_scheduler"] = scheduler.state_dict() + prefixes=[f"__{self._rank}_", ".metadata", "common.pt", "metadata.json"] + # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory # this should be improved to download only the relevant shards for this actor to load - # prefixes=[f"__{self._rank}_", ".metadata", "common.pt", "metadata.json"] - with io.local_read_dir(ckpt_dir) as read_dir: + with io.local_read_dir(ckpt_dir, local_path=self.cfg.local_ckpt_dir, prefixes=prefixes) as read_dir: # Load the checkpoint in parallel. load_strategy = get_default_load_sharded_strategy(read_dir) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_model_parallel_group() - ) + load_strategy = FullyParallelLoadStrategyWrapper(load_strategy) state_dict = dist_checkpointing.load( sharded_state_dict=sharded_state_dict, checkpoint_dir=read_dir, sharded_strategy=load_strategy, + strict="assume_ok_unexpected", ) # Load the model, optimizer, and scheduler state dicts. 
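The patch above speeds up `load_checkpoint` by having each actor pull only the checkpoint files it actually needs from cloud storage: its own torch_dist shard (matched by the `__{rank}_` prefix) plus the shared `.metadata`, `common.pt`, and `metadata.json` files. Below is a minimal standalone sketch of that per-rank filter; `select_files_for_rank` and the `__<rank>_<n>.distcp` shard naming are illustrative assumptions, not code or guaranteed naming from the repository.

```python
# Standalone sketch of the per-rank prefix filter used when downloading a
# torch_dist checkpoint from S3/GCS. Hypothetical helper for illustration only:
# the "__<rank>_<n>.distcp" shard naming is an assumption, and
# select_files_for_rank does not exist in the repository.
from typing import List


def select_files_for_rank(remote_files: List[str], rank: int) -> List[str]:
    """Keep only the files a given actor rank needs: its own shard plus shared metadata."""
    prefixes = [f"__{rank}_", ".metadata", "common.pt", "metadata.json"]
    selected = []
    for path in remote_files:
        name = path.rsplit("/", 1)[-1]  # match on the file name, like the prefixes above
        if any(name.startswith(prefix) for prefix in prefixes):
            selected.append(path)
    return selected


if __name__ == "__main__":
    files = [
        "ckpt/__0_0.distcp",
        "ckpt/__1_0.distcp",
        "ckpt/.metadata",
        "ckpt/common.pt",
        "ckpt/metadata.json",
    ]
    # Rank 1 skips rank 0's shard and downloads only its own shard plus the shared files.
    print(select_files_for_rank(files, rank=1))
```

Filtering on the writer-rank prefix keeps each actor's download roughly proportional to one shard, at the cost of assuming the torch_dist shard-naming scheme stays stable across Megatron and PyTorch versions.
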
From ea97b77ed535f346173a5d3438112f86f50fa1a8 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 27 Nov 2025 01:30:15 +0000 Subject: [PATCH 15/15] test with 8b on 4 nodes - models with tied word embeddings might fail with > 2 nodes now, but that is ok --- megatron_ray_fault_tolerant/file_io.py | 14 ++++++++------ megatron_ray_fault_tolerant/main.py | 8 ++++---- megatron_ray_fault_tolerant/megatron_actor.py | 1 + 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/megatron_ray_fault_tolerant/file_io.py b/megatron_ray_fault_tolerant/file_io.py index 67f6a97..545aa8a 100644 --- a/megatron_ray_fault_tolerant/file_io.py +++ b/megatron_ray_fault_tolerant/file_io.py @@ -16,7 +16,7 @@ from loguru import logger from datetime import datetime, timezone, timedelta from typing import List, Optional - +from timer import Timer # Optional AWS deps (present when s3fs is installed) try: import botocore.session as _botocore_session @@ -286,8 +286,9 @@ def local_work_dir(output_path: str): yield temp_dir finally: # Upload everything from temp_dir to cloud path - upload_directory(temp_dir, output_path) - logger.info(f"Uploaded directory contents to {output_path}") + with Timer("Uploading directory contents to cloud path..."): + upload_directory(temp_dir, output_path) + logger.info(f"Uploaded directory contents to {output_path}") else: # For local paths, ensure directory exists and use it directly makedirs(output_path, exist_ok=True) @@ -324,9 +325,10 @@ def local_read_dir(input_path: str, local_path: Optional[str] = None, prefixes: yield temp_dir else: # Download everything from cloud path to local_path - download_directory(input_path, local_path, prefixes) - logger.info(f"Downloaded directory contents from {input_path}") - yield local_path + with Timer("Downloading directory contents to cloud path..."): + download_directory(input_path, local_path, prefixes) + logger.info(f"Downloaded directory contents from {input_path}") + yield local_path else: # For local paths, use directly (but check it exists) if not exists(input_path): diff --git a/megatron_ray_fault_tolerant/main.py b/megatron_ray_fault_tolerant/main.py index b8d8d5d..da86aef 100644 --- a/megatron_ray_fault_tolerant/main.py +++ b/megatron_ray_fault_tolerant/main.py @@ -53,17 +53,17 @@ class MegatronConfig: @dataclass class Config: - model: str = "Qwen/Qwen3-0.6B" - num_nodes: int = 2 + model: str = "Qwen/Qwen3-8B" + num_nodes: int = 4 num_gpus_per_node: int = 4 mini_batch_size: int = 16 num_spare_gpus: int = 4 micro_train_batch_size_per_gpu: int = 2 megatron_config: MegatronConfig = field(default_factory=MegatronConfig) ckpt_dir: str = ( - os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt4/" + os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/megatron_fault_tolerance/ckpt6/" ) - local_ckpt_dir: str = "/tmp/megatron_fault_tolerance/ckpt4/" # used for storing checkpoint on local disk when downloading from cloud + local_ckpt_dir: str = "/tmp/megatron_fault_tolerance/ckpt6/" # used for storing checkpoint on local disk when downloading from cloud # algorithm config eps_clip_low: float = 0.2 eps_clip_high: float = 0.2 diff --git a/megatron_ray_fault_tolerant/megatron_actor.py b/megatron_ray_fault_tolerant/megatron_actor.py index 95653b2..97a5903 100644 --- a/megatron_ray_fault_tolerant/megatron_actor.py +++ b/megatron_ray_fault_tolerant/megatron_actor.py @@ -469,6 +469,7 @@ def load_checkpoint( # currently, if the ckpt_dir is a cloud path, we download all the contents of the cloud path to a local directory # this should be 
improved to download only the relevant shards for this actor to load with io.local_read_dir(ckpt_dir, local_path=self.cfg.local_ckpt_dir, prefixes=prefixes) as read_dir: + dist.barrier() # Load the checkpoint in parallel. load_strategy = get_default_load_sharded_strategy(read_dir) load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)
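

The `Timer` helper added in `timer.py` (patch 13) is what the later patches wrap around checkpoint upload/download and the hot-spare recovery drill. A minimal usage sketch follows, assuming the script runs alongside `timer.py` so the import resolves; the timed bodies are placeholders, not the repository's actual checkpoint calls.

```python
# Minimal usage sketch of the Timer context manager from timer.py (patch 13).
# Assumes this runs from the megatron_ray_fault_tolerant directory so that
# "from timer import Timer" resolves. The sleeps stand in for real work such as
# checkpoint upload/download or the failure-recovery drill.
import asyncio
import time

from timer import Timer


def sync_example():
    # Logs "Started: ..." on entry and "Finished: ..., time cost: X.XXs" on exit.
    with Timer("Saving checkpoint (placeholder work)"):
        time.sleep(0.2)


async def async_example():
    # Timer also implements __aenter__/__aexit__, so it works under "async with".
    async with Timer("Loading checkpoint (placeholder work)"):
        await asyncio.sleep(0.2)


if __name__ == "__main__":
    sync_example()
    asyncio.run(async_example())
```

Because `timer.py` logs through `logger.opt(depth=1)`, the Started/Finished lines are attributed to the calling code rather than to `timer.py` itself, which keeps the timing output readable when several stages are instrumented at once.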