Issue/74: Adapt the Llama model to InfiniCore::nn::module #75
base: main
Conversation
Force-pushed from 3553814 to 294383d
# Check if forward method is available
if hasattr(infinilm_model._model, 'forward'):
    # Call forward method
    infini_logits = infinilm_model._model.forward(
infinilm_model is an instance of the LlamaForCausalLM class. LlamaForCausalLM should provide a call-operator overload so that this can be invoked as infinilm_model(infini_input_ids, infini_position_ids, None).
Provide a function along these lines in LlamaForCausalLM:

def __call__(self, infini_input_ids, infini_position_ids, kv_cache=None):
    return self._model.forward(infini_input_ids, infini_position_ids, kv_cache)
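With such an overload in place, the call site quoted above can drop the direct _model access. A minimal sketch, using only the names from the diff:

# Hypothetical call site once __call__ is available
infini_logits = infinilm_model(infini_input_ids, infini_position_ids, None)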
elif not isinstance(config, LlamaConfig):
    config = LlamaConfig(**config)

if device is None:
Hasn't the Device() class already been removed?
    return self._cpp_config


class LlamaModel(infinicore.nn.Module):
class LlamaModel is never used; it can be deleted.
Force-pushed from fae0713 to f94518b
python/infinilm/generation/utils.py (outdated)

eos_token_id = config.eos_token_id
# eos_token_id = config.eos_token_id
eos_token_id = 128001
eos_token_id = config.eos_token_id
eos_token_id is not a fixed number; it is read from config.json and may differ between models.
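A minimal sketch of reading the id from config.json instead of hardcoding 128001 (the list-vs-int handling is an assumption; Hugging Face configs use either form):

import json

def load_eos_token_ids(config_path):
    # eos_token_id differs between models, so always take it from config.json.
    with open(config_path) as f:
        config = json.load(f)
    eos = config["eos_token_id"]
    # Some configs store a single int, others a list of ids.
    return eos if isinstance(eos, list) else [eos]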
python/infinilm/generation/utils.py (outdated)

# Process the outputs
# -------------------------------------------------------------------------- #
token_scores = logits
seq_l = logits.shape[1]
Move this logic into C++, so that the Python side already receives only the last token.
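For reference, the selection currently done in Python amounts to the following slice (assuming NumPy-style indexing on a logits tensor of shape [batch, seq_len, vocab]); the suggestion is that the C++ side return only this last position:

seq_len = logits.shape[1]
last_token_scores = logits[:, seq_len - 1, :]  # what Python would then receive directly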
    return self._cpp_config


class LlamaModel(infinicore.nn.Module):
class LlamaModel is defined but never used; this class can be removed.
@@ -1,15 +1,224 @@
from ....generation.utils import GenerationMixin
import infinicore
from infinicore.device import device as Device
Why rename it to Device instead of using infinicore.device directly? It reads as if these were two different types.
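Using the module's own name keeps the type obvious. A minimal sketch based on the usage shown later in examples/llama.py (the device string value here is assumed):

import infinicore

device_str = "cpu"  # assumed value; examples/llama.py derives this from its arguments
# No alias needed: construct the device through infinicore.device directly.
infini_device = infinicore.device(device_str, 0)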
examples/llama.py (outdated)

if __name__ == "__main__":
    if False:
        model_path = "/var/qy_home/zenghua/.cache/modelscope/hub/models/LLM-Research/Llama-3.2-1B-Instruct"
This is temporary test code and needs to be removed.
infini_device = infinicore.device(device_str, 0)
infini_dtype = infinicore.bfloat16
infini_dtype = infinicore.float32 if backend == "cpp" else infinicore.bfloat16
The cpp backend defaults to float32; does the Llama C++ path also support fp16 or bf16?
: cache_position(0), max_capacity(0), initialized(false),
  // Create empty placeholder tensors (will be replaced on first use)
  k_cache(infinicore::Tensor::empty({1, 1, 1}, infinicore::DataType::F32,
                                    infinicore::Device(infinicore::Device::Type::CPU, 0))),
Written this way, can it only run in Float32?
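A dtype-agnostic variant (a Python-flavored sketch of the idea, not the PR's C++ code; numpy stands in for infinicore tensors) would defer allocation until the first update and inherit the incoming dtype instead of a hardcoded F32/CPU placeholder:

import numpy as np

class KVCache:
    def __init__(self):
        self.k_cache = None
        self.v_cache = None

    def update(self, k, v):
        if self.k_cache is None:
            # First use: inherit dtype/shape from the incoming tensors.
            self.k_cache, self.v_cache = k, v
        else:
            # Append along the sequence axis: [n_kv_head, seq, head_dim].
            self.k_cache = np.concatenate([self.k_cache, k], axis=1)
            self.v_cache = np.concatenate([self.v_cache, v], axis=1)
        return self.k_cache, self.v_cache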
// Rotary Position Embeddings (RoPE)
INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb);
With this, every layer ends up with its own RoPE object.
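A common alternative (a hypothetical sketch, not the PR's code) is to hold a single rotary embedding at the model level and reuse its output across all decoder layers:

class Model:
    def __init__(self, layers, rotary_emb):
        self.rotary_emb = rotary_emb  # one shared RoPE module for the whole model
        self.layers = layers

    def forward(self, hidden_states, position_ids):
        # Compute cos/sin once per forward pass and hand it to every layer,
        # instead of instantiating a RoPE object inside each attention block.
        cos_sin = self.rotary_emb(position_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, cos_sin)
        return hidden_states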
csrc/models/llama/pybind11_llama.cc (outdated)

@@ -0,0 +1,4 @@
#include "pybind11_llama.hpp"
The file has no content.
csrc/models/llama/pybind11_llama.hpp (outdated)

    } else {
        throw std::runtime_error("Invalid KV cache type. Expected LlamaKVCache or None.");
    }
}
Too much nesting here.
csrc/models/llama/pybind11_llama.hpp (outdated)

// Try to cast to LlamaKVCache shared_ptr
try {
    auto cache = item.cast<std::shared_ptr<LlamaKVCache>>();
    kv_caches_vec.push_back(cache.get());
Too much nesting here.
examples/llama.py (outdated)

        _dec.Fuse(),
    ]
)
# if "llama" == config.model_type:
Restore the check: if "llama" == config.model_type:
|
There are 4 commits; they can be squashed into one.
}

infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
This is written far too elaborately; in principle it should not be much longer than the Python version.
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
                                           const infinicore::Tensor &position_ids,
                                           void *kv_cache,
                                           const HookRegistry *hook_registry,
Where is the hook actually used?
It was used for debugging; it has been removed from the committed code.
// For batch=1 (common in inference), reshape to [n_q_head, seq_len, head_dim]
// Note: For batch > 1, this would need to be handled differently
// Make contiguous before final view since permute can make tensor non-contiguous
auto q_permuted_cont = q_permuted->contiguous();
Why so many contiguous calls? What about performance?
    return output;
}

infinicore::Tensor LlamaAttention::project_q(const infinicore::Tensor &hidden_states) const {
Why have one- or two-line functions like this?
 * Stores key and value caches with shape [n_kv_head, capacity, head_dim]
 * Similar to DynamicLayer in Python cache_utils.py
 */
struct LlamaKVCache {
The cache should be generic, not Llama-specific.
/**
 * @brief Hook registry for managing hooks
 */
class HookRegistry {
What is this for?
}

auto scaling_broadcast = scaling_value_tensor->as_strided(attn_weight->shape(), {0, 0, 0});
attn_weight = infinicore::op::mul(attn_weight, scaling_broadcast);
This scale should be passed directly as the gemm's alpha.
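The idea, illustrated with a generic batched GEMM call (torch is used purely for illustration; infinicore's gemm signature is not shown in this diff):

import torch

q = torch.randn(8, 16, 64)  # [n_head, seq_len, head_dim], example shapes
k = torch.randn(8, 16, 64)
scaling = 64 ** -0.5

# Fold the 1/sqrt(head_dim) scaling into the matmul's alpha instead of
# materializing a broadcast tensor and doing a separate elementwise mul.
attn_weight = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1]),  # ignored because beta=0
    q,
    k.transpose(1, 2),
    beta=0.0,
    alpha=scaling,
)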
@@ -0,0 +1,184 @@
#pragma once
This could live in infinicore as generic nn::module infrastructure.
@@ -3,7 +3,8 @@
#include <pybind11/pybind11.h>
Name it following pybind11/models/llama.hpp.
Signed-off-by: Ceng23333 <[email protected]>
 * - LlamaDecoderLayer: Single transformer decoder layer
 * - LlamaModel: Core transformer model (without LM head)
 * - LlamaForCausalLM: Complete model with language modeling head
 * - HookRegistry: Hook system for capturing intermediate values
There is still a comment here that has not been removed.
#74
End-to-end verification screenshot:
