|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "infinicore/tensor.hpp" |
| 4 | +#include <functional> |
| 5 | +#include <string> |
| 6 | +#include <memory> |
| 7 | +#include <unordered_map> |
| 8 | + |
| 9 | +namespace infinilm::models::debug_utils { |
| 10 | + |
| 11 | +// TODO: move to InfiniCore as common utils in future work |
| 12 | + |
| 13 | +/** |
| 14 | + * @brief Hook callback type for capturing intermediate values (DEBUG ONLY) |
| 15 | + * |
| 16 | + * Hook functions are called with: |
| 17 | + * - name: Identifier for the intermediate value (e.g., "layer0_q_after_proj") |
| 18 | + * - tensor: The intermediate tensor value |
| 19 | + * - layer_idx: Layer index (for layer-specific hooks, -1 if not applicable) |
| 20 | + * |
| 21 | + * NOTE: This is a debug utility. Do not use in production code. |
| 22 | + */ |
| 23 | +using HookCallback = std::function<void(const std::string &name, const infinicore::Tensor &tensor, int layer_idx)>; |
| 24 | + |
| 25 | +/** |
| 26 | + * @brief Hook registry for managing hooks (DEBUG ONLY) |
| 27 | + * |
| 28 | + * NOTE: This is a debug utility for capturing intermediate tensor values |
| 29 | + * during model execution. Do not use in production code. |
| 30 | + */ |
| 31 | +class HookRegistry { |
| 32 | +public: |
| 33 | + /** |
| 34 | + * @brief Register a hook callback |
| 35 | + * |
| 36 | + * @param name Hook name (can be pattern like "layer0_*" or specific name) |
| 37 | + * @param callback Hook callback function |
| 38 | + */ |
| 39 | + void register_hook(const std::string &name, HookCallback callback); |
| 40 | + |
| 41 | + /** |
| 42 | + * @brief Call hook if registered |
| 43 | + * |
| 44 | + * @param name Full hook name |
| 45 | + * @param tensor Tensor to pass to hook |
| 46 | + * @param layer_idx Layer index (-1 if not applicable) |
| 47 | + */ |
| 48 | + void call_hook(const std::string &name, const infinicore::Tensor &tensor, int layer_idx = -1) const; |
| 49 | + |
| 50 | + /** |
| 51 | + * @brief Clear all hooks |
| 52 | + */ |
| 53 | + void clear(); |
| 54 | + |
| 55 | + /** |
| 56 | + * @brief Check if any hooks are registered |
| 57 | + */ |
| 58 | + bool has_hooks() const { return !hooks_.empty(); } |
| 59 | + |
| 60 | +private: |
| 61 | + std::unordered_map<std::string, HookCallback> hooks_; |
| 62 | +}; |
| 63 | + |
/**
 * @brief Shorthand for calling register_hook() through a registry pointer (DEBUG ONLY)
 *
 * Expands to a single expression that dereferences the first argument with `->`.
 * No null check is performed here, so the argument must point at a live registry.
 *
 * Usage: REGISTER_HOOK(registry, "hook_name", callback)
 */
#define REGISTER_HOOK(hook_registry, hook_name, hook_fn) \
    (hook_registry)->register_hook(hook_name, hook_fn)
| 71 | + |
/**
 * @brief Fire a hook with automatic null and has_hooks checks (DEBUG ONLY)
 *
 * Usage: CALL_HOOK(registry, "hook_name", tensor)
 * Note: layer_idx is always passed as -1.
 *
 * The `registry` argument is evaluated exactly once and bound to a local
 * reference, so expressions with side effects or non-trivial cost are safe.
 * The `name` and `tensor` arguments are evaluated only when the registry is
 * non-null and reports has_hooks(), which keeps the no-hook path cheap.
 */
#define CALL_HOOK(registry, name, tensor) \
    do { \
        /* Bind once: the macro argument must not be expanded repeatedly. */ \
        auto &&infinilm_call_hook_registry_ = (registry); \
        if (infinilm_call_hook_registry_ && infinilm_call_hook_registry_->has_hooks()) { \
            infinilm_call_hook_registry_->call_hook(name, tensor, -1); \
        } \
    } while (0)
| 84 | + |
/**
 * @brief Fire a hook with an explicit layer index, with automatic null and
 *        has_hooks checks (DEBUG ONLY)
 *
 * Usage: CALL_HOOK_LAYER(registry, "hook_name", tensor, layer_idx)
 *
 * The `registry` argument is evaluated exactly once and bound to a local
 * reference, so expressions with side effects or non-trivial cost are safe.
 * The `name`, `tensor` and `layer_idx` arguments are evaluated only when the
 * registry is non-null and reports has_hooks().
 */
#define CALL_HOOK_LAYER(registry, name, tensor, layer_idx) \
    do { \
        /* Bind once: the macro argument must not be expanded repeatedly. */ \
        auto &&infinilm_call_hook_registry_ = (registry); \
        if (infinilm_call_hook_registry_ && infinilm_call_hook_registry_->has_hooks()) { \
            infinilm_call_hook_registry_->call_hook(name, tensor, layer_idx); \
        } \
    } while (0)
| 96 | + |
| 97 | +/** |
| 98 | + * @brief Macros to simplify hook_registry and hook_prefix management in model classes |
| 99 | + */ |
| 100 | + |
// Declare the hook_registry_ and hook_prefix_ member variables inside a class body.
// Expands to two complete declarations (the trailing semicolon is part of the macro).
// The registry type is spelled with its fully qualified name because macros are
// expanded in the scope of the use site, not in the namespace of this header.
#define HOOK_REGISTRY_MEMBER() \
    std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> hook_registry_; \
    std::string hook_prefix_;
| 105 | + |
// Define set_hook_registry() for a class with no submodules to forward to.
// The generated member stores the registry and prefix in hook_registry_ and
// hook_prefix_ (see HOOK_REGISTRY_MEMBER) and does nothing else.
// The registry type is fully qualified so the macro expands correctly from any namespace.
#define SET_HOOK_REGISTRY_SIMPLE() \
    void set_hook_registry(const std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> &hook_registry, \
                           const std::string &hook_prefix = "") { \
        hook_registry_ = hook_registry; \
        hook_prefix_ = hook_prefix; \
    }
| 112 | + |
// Helper macro to build an incremental hook prefix.
// Yields std::string(name) when `prefix` is empty, otherwise prefix + "_" + name.
// `prefix` must be a std::string expression; it is parenthesized so compound
// expressions such as `a + b` work. Note that `prefix` is expanded twice
// (once for the empty() test, once for the concatenation) when it is non-empty.
#define BUILD_HOOK_PREFIX(prefix, name) \
    ((prefix).empty() ? std::string(name) : (prefix) + "_" + std::string(name))
| 116 | + |
// Define set_hook_registry() and forward the registry to one or two submodules.
// Usage: SET_HOOK_REGISTRY(submodule1) or SET_HOOK_REGISTRY(submodule1, submodule2)
// Each submodule receives the parent's prefix extended with its own name
// (e.g., "layer0" -> "layer0_attention").
// Note: only 1 or 2 submodules are supported. To support more, add another
// SET_HOOK_REGISTRY_<k> macro and extend the selector list and placeholder count below.
#define SET_HOOK_REGISTRY(...) \
    SET_HOOK_REGISTRY_IMPL(__VA_ARGS__)

// Argument-count dispatch: the caller's arguments are placed in front of the list
// of candidate implementations, and the selector picks whatever lands in the third
// position; the chosen macro is then invoked with the original arguments.
#define SET_HOOK_REGISTRY_IMPL(...) \
    SET_HOOK_REGISTRY_GET_NTH(__VA_ARGS__, SET_HOOK_REGISTRY_2, SET_HOOK_REGISTRY_1, SET_HOOK_REGISTRY_0, )(__VA_ARGS__)

// How the third position is filled:
//   1 arg  (a):    a, SET_HOOK_REGISTRY_2, SET_HOOK_REGISTRY_1, ... -> SET_HOOK_REGISTRY_1
//   2 args (a, b): a, b, SET_HOOK_REGISTRY_2, ...                   -> SET_HOOK_REGISTRY_2
// An empty invocation still counts as one (empty) argument, so it also selects
// SET_HOOK_REGISTRY_1 rather than SET_HOOK_REGISTRY_0; use SET_HOOK_REGISTRY_SIMPLE()
// for classes without submodules.
// The trailing empty argument in SET_HOOK_REGISTRY_IMPL ensures the variadic
// part of the selector is never left without an argument.
#define SET_HOOK_REGISTRY_GET_NTH(slot1, slot2, chosen, spare, ...) chosen
| 136 | + |
// Implementation for 0 submodules: stores the registry and prefix, forwards nothing.
// Same expansion as SET_HOOK_REGISTRY_SIMPLE(). The SET_HOOK_REGISTRY(...) dispatcher
// never selects this macro (an empty invocation is routed to SET_HOOK_REGISTRY_1),
// so it is only useful when invoked directly.
// The registry type is fully qualified so the macro expands correctly from any namespace.
#define SET_HOOK_REGISTRY_0() \
    void set_hook_registry(const std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> &hook_registry, \
                           const std::string &hook_prefix = "") { \
        hook_registry_ = hook_registry; \
        hook_prefix_ = hook_prefix; \
    }
| 143 | + |
// Implementation for 1 submodule.
// `submodule` is the member name without its trailing underscore: the generated code
// accesses the member `submodule_` and uses the text "submodule" as the prefix suffix.
// The member must be testable in an `if` and dereferenceable with `->`; a null/false
// member is skipped.
// The registry type is fully qualified so the macro expands correctly from any namespace.
#define SET_HOOK_REGISTRY_1(submodule) \
    void set_hook_registry(const std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> &hook_registry, \
                           const std::string &hook_prefix = "") { \
        hook_registry_ = hook_registry; \
        hook_prefix_ = hook_prefix; \
        if (submodule##_) { \
            const std::string submodule_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule); \
            submodule##_->set_hook_registry(hook_registry, submodule_prefix); \
        } \
    }
| 154 | + |
// Implementation for 2 submodules.
// Each argument is a member name without its trailing underscore: the generated code
// accesses `submodule1_` / `submodule2_` and uses the argument text as the prefix suffix.
// Submodules are handled independently, in argument order; a null/false member is skipped.
// The registry type is fully qualified so the macro expands correctly from any namespace.
#define SET_HOOK_REGISTRY_2(submodule1, submodule2) \
    void set_hook_registry(const std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> &hook_registry, \
                           const std::string &hook_prefix = "") { \
        hook_registry_ = hook_registry; \
        hook_prefix_ = hook_prefix; \
        if (submodule1##_) { \
            const std::string submodule1_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule1); \
            submodule1##_->set_hook_registry(hook_registry, submodule1_prefix); \
        } \
        if (submodule2##_) { \
            const std::string submodule2_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule2); \
            submodule2##_->set_hook_registry(hook_registry, submodule2_prefix); \
        } \
    }
| 169 | + |
// Define set_hook_registry() for a class holding an indexable container of submodules.
// `vec_name` is the member name without its trailing underscore (the code uses `vec_name_`).
// Every non-null element i receives the prefix "layer<i>" ("layer0", "layer1", ...);
// if the parent has a prefix it becomes "parent_layer0", "parent_layer1", etc.
// The "layer" stem is fixed and does not depend on `vec_name`.
// The registry type is fully qualified so the macro expands correctly from any namespace.
#define SET_HOOK_REGISTRY_VEC(vec_name) \
    void set_hook_registry(const std::shared_ptr<::infinilm::models::debug_utils::HookRegistry> &hook_registry, \
                           const std::string &hook_prefix = "") { \
        hook_registry_ = hook_registry; \
        hook_prefix_ = hook_prefix; \
        for (size_t i = 0; i < vec_name##_.size(); ++i) { \
            if (vec_name##_[i]) { \
                const std::string layer_name = "layer" + std::to_string(i); \
                const std::string item_prefix = BUILD_HOOK_PREFIX(hook_prefix, layer_name); \
                vec_name##_[i]->set_hook_registry(hook_registry, item_prefix); \
            } \
        } \
    }
| 185 | + |
| 186 | +} // namespace infinilm::models::debug_utils |
0 commit comments