Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/neural-graphics-primitives/nerf_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct NerfDataset {
std::vector<TrainingImageMetadata> metadata;
GPUMemory<TrainingImageMetadata> metadata_gpu;

void update_metadata(int first = 0, int last = -1);
void update_metadata(int first = 0, int last = -1, bool in_cpu_ram = false);

std::vector<TrainingXForm> xforms;
std::vector<std::string> paths;
Expand Down Expand Up @@ -86,7 +86,7 @@ struct NerfDataset {
return (has_light_dirs ? 3u : 0u) + n_extra_learnable_dims;
}

void set_training_image(int frame_idx, const ivec2& image_resolution, const void* pixels, const void* depth_pixels, float depth_scale, bool image_data_on_gpu, EImageDataType image_type, EDepthDataType depth_type, float sharpen_amount = 0.f, bool white_transparent = false, bool black_transparent = false, uint32_t mask_color = 0, const Ray *rays = nullptr);
void set_training_image(int frame_idx, const ivec2& image_resolution, const void* pixels, const void* depth_pixels, float depth_scale, bool image_data_on_gpu, EImageDataType image_type, EDepthDataType depth_type, float sharpen_amount = 0.f, bool white_transparent = false, bool black_transparent = false, uint32_t mask_color = 0, const Ray *rays = nullptr, bool in_cpu_ram = false);

vec3 nerf_direction_to_ngp(const vec3& nerf_dir) {
vec3 result = nerf_dir;
Expand Down Expand Up @@ -168,7 +168,7 @@ struct NerfDataset {
}
};

NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amount = 0.f);
NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amount = 0.f, bool in_cpu_ram = false);
NerfDataset create_empty_nerf_dataset(size_t n_images, int aabb_scale = 1, bool is_hdr = false);

}
1 change: 1 addition & 0 deletions include/neural-graphics-primitives/testbed.h
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ class Testbed {
NerfDataset dataset;
int n_images_for_training = 0; // how many images to train from, as a high watermark compared to the dataset size
int n_images_for_training_prev = 0; // how many images we saw last time we updated the density grid
bool dataset_in_cpu_ram = false;

struct ErrorMap {
GPUMemory<float> data;
Expand Down
2 changes: 2 additions & 0 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def parse_args():
parser.add_argument("--height", "--screenshot_h", type=int, default=0, help="Resolution height of GUI and screenshots.")

parser.add_argument("--gui", action="store_true", help="Run the testbed GUI interactively.")
parser.add_argument("--nerf_dataset_in_cpu_ram", action="store_true", help="Store entire training dataset in cpu ram, and only send parts to vram for training iteration.")
parser.add_argument("--train", action="store_true", help="If the GUI is enabled, controls whether training starts immediately.")
parser.add_argument("--n_steps", type=int, default=-1, help="Number of steps to train for before quitting.")
parser.add_argument("--second_window", action="store_true", help="Open a second window containing a copy of the main output.")
Expand All @@ -93,6 +94,7 @@ def get_scene(scene):

testbed = ngp.Testbed()
testbed.root_dir = ROOT_DIR
testbed.m_nerf.training.dataset_in_cpu_ram = args.nerf_dataset_in_cpu_ram

for file in args.files:
scene_info = get_scene(file)
Expand Down
8 changes: 8 additions & 0 deletions src/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ int main_func(const std::vector<std::string>& arguments) {
{"no-train"},
};

Flag nerf_dataset_in_cpu_ram{
parser,
"NERF_DATASET_IN_CPU_RAM",
"Store the entire training dataset in CPU RAM, and use Nvidia Unified Memory for GPU to access data for training iterations.",
{"nerf-dataset-in-cpu-ram"},
};

ValueFlag<string> scene_flag{
parser,
"SCENE",
Expand Down Expand Up @@ -150,6 +157,7 @@ int main_func(const std::vector<std::string>& arguments) {

Testbed testbed;

testbed.m_nerf.training.dataset_in_cpu_ram = nerf_dataset_in_cpu_ram;
for (auto file : get(files)) {
testbed.load_file(file);
}
Expand Down
36 changes: 22 additions & 14 deletions src/nerf_loader.cu
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,11 @@ bool read_focal_length(const nlohmann::json &json, vec2 &focal_length, const ive
return true;
}

NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amount) {
NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amount, bool in_cpu_ram) {
if (jsonpaths.empty()) {
throw std::runtime_error{"Cannot load NeRF data from an empty set of paths."};
}

tlog::info() << "Loading NeRF dataset from";

NerfDataset result{};

std::ifstream f{native_string(jsonpaths.front())};
Expand Down Expand Up @@ -727,26 +725,28 @@ NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amou
result.sharpness_data.enlarge( result.sharpness_resolution.x * result.sharpness_resolution.y * result.n_images );

// copy / convert images to the GPU
auto progress_to_gpu = tlog::progress(result.n_images);
tlog::info() << "Copying / converting images to GPU...";
for (uint32_t i = 0; i < result.n_images; ++i) {
const LoadedImageInfo& m = images[i];
result.set_training_image(i, m.res, m.pixels, m.depth_pixels, m.depth_scale * result.scale, m.image_data_on_gpu, m.image_type, EDepthDataType::UShort, sharpen_amount, m.white_transparent, m.black_transparent, m.mask_color, m.rays);
result.set_training_image(i, m.res, m.pixels, m.depth_pixels, m.depth_scale * result.scale, m.image_data_on_gpu, m.image_type, EDepthDataType::UShort, sharpen_amount, m.white_transparent, m.black_transparent, m.mask_color, m.rays, in_cpu_ram);
CUDA_CHECK_THROW(cudaDeviceSynchronize());
}
CUDA_CHECK_THROW(cudaDeviceSynchronize());
// free memory
for (uint32_t i = 0; i < result.n_images; ++i) {
// free memory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't recall exactly why, but I believe there was a reason for the two loops to be separate. I think the underlying memory might be aliased in some cases -- please revert. Putting the progress bar in the first loop likely matches current behavior closely enough.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the heads up.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You didn't actually put the progress bar update into the first loop. Putting it in the second is somewhat meaningless -- calling free is pretty much free.

if (images[i].image_data_on_gpu) {
CUDA_CHECK_THROW(cudaFree(images[i].pixels));
} else {
free(images[i].pixels);
}
free(images[i].rays);
free(images[i].depth_pixels);
progress_to_gpu.update(i);
}
CUDA_CHECK_THROW(cudaDeviceSynchronize());
tlog::success() << "Copied / converted " << images.size() << " images to GPU after " << tlog::durationToString(progress_to_gpu.duration());
return result;
}

void NerfDataset::set_training_image(int frame_idx, const ivec2& image_resolution, const void* pixels, const void* depth_pixels, float depth_scale, bool image_data_on_gpu, EImageDataType image_type, EDepthDataType depth_type, float sharpen_amount, bool white_transparent, bool black_transparent, uint32_t mask_color, const Ray *rays) {
void NerfDataset::set_training_image(int frame_idx, const ivec2& image_resolution, const void* pixels, const void* depth_pixels, float depth_scale, bool image_data_on_gpu, EImageDataType image_type, EDepthDataType depth_type, float sharpen_amount, bool white_transparent, bool black_transparent, uint32_t mask_color, const Ray *rays, bool in_cpu_ram) {
if (frame_idx < 0 || frame_idx >= n_images) {
throw std::runtime_error{"NerfDataset::set_training_image: invalid frame index"};
}
Expand All @@ -772,8 +772,13 @@ void NerfDataset::set_training_image(int frame_idx, const ivec2& image_resolutio
}

// copy or convert the pixels
pixelmemory[frame_idx].resize(img_size * image_type_size(image_type));
void* dst = pixelmemory[frame_idx].data();
size_t total_image_mem_size = img_size * image_type_size(image_type);
void* dst;
pixelmemory[frame_idx] = GPUMemory<uint8_t>(total_image_mem_size, in_cpu_ram);
dst = pixelmemory[frame_idx].data();
if (in_cpu_ram) {
CUDA_CHECK_THROW(cudaMemAdvise(dst, pixelmemory[frame_idx].get_bytes(), cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId));
}

switch (image_type) {
default: throw std::runtime_error{"unknown image type in set_training_image"};
Expand Down Expand Up @@ -846,10 +851,10 @@ void NerfDataset::set_training_image(int frame_idx, const ivec2& image_resolutio
raymemory[frame_idx].free_memory();
}
metadata[frame_idx].rays = raymemory[frame_idx].data();
update_metadata(frame_idx, frame_idx + 1);
update_metadata(frame_idx, frame_idx + 1, in_cpu_ram);
}

void NerfDataset::update_metadata(int first, int last) {
void NerfDataset::update_metadata(int first, int last, bool in_cpu_ram) {
if (last < 0) {
last = n_images;
}
Expand All @@ -864,7 +869,10 @@ void NerfDataset::update_metadata(int first, int last) {
}

metadata_gpu.enlarge(last);
CUDA_CHECK_THROW(cudaMemcpy(metadata_gpu.data() + first, metadata.data() + first, n * sizeof(TrainingImageMetadata), cudaMemcpyHostToDevice));
if (!in_cpu_ram) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong to me. The metadata is still stored on the GPU -- and it's so small that it wouldn't make sense to offload it to CPU RAM anyway.

size_t total_size = n * sizeof(TrainingImageMetadata);
CUDA_CHECK_THROW(cudaMemcpy(metadata_gpu.data() + first, metadata.data() + first, total_size, cudaMemcpyHostToDevice));
}
}

}
1 change: 1 addition & 0 deletions src/python_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ PYBIND11_MODULE(pyngp, m) {
py::class_<Testbed::Nerf::Training>(nerf, "Training")
.def_readwrite("random_bg_color", &Testbed::Nerf::Training::random_bg_color)
.def_readwrite("n_images_for_training", &Testbed::Nerf::Training::n_images_for_training)
.def_readwrite("dataset_in_cpu_ram", &Testbed::Nerf::Training::dataset_in_cpu_ram)
.def_readwrite("linear_colors", &Testbed::Nerf::Training::linear_colors)
.def_readwrite("loss_type", &Testbed::Nerf::Training::loss_type)
.def_readwrite("depth_loss_type", &Testbed::Nerf::Training::depth_loss_type)
Expand Down
6 changes: 6 additions & 0 deletions src/testbed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,19 @@ void Testbed::set_mode(ETestbedMode mode) {
return;
}

// Temporarily store the settings that shouldn't reset
bool tmp_dataset_in_cpu_ram = m_nerf.training.dataset_in_cpu_ram;

// Reset mode-specific members
m_image = {};
m_mesh = {};
m_nerf = {};
m_sdf = {};
m_volume = {};

// Restore the tmp values after reset
m_nerf.training.dataset_in_cpu_ram = tmp_dataset_in_cpu_ram;

// Kill training-related things
m_encoding = {};
m_loss = {};
Expand Down
2 changes: 1 addition & 1 deletion src/testbed_nerf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2459,7 +2459,7 @@ void Testbed::load_nerf(const fs::path& data_path) {

const auto prev_aabb_scale = m_nerf.training.dataset.aabb_scale;

m_nerf.training.dataset = ngp::load_nerf(json_paths, m_nerf.sharpen);
m_nerf.training.dataset = ngp::load_nerf(json_paths, m_nerf.sharpen, m_nerf.training.dataset_in_cpu_ram);

// Check if the NeRF network has been previously configured.
// If it has not, don't reset it.
Expand Down