@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "linear_impl.h"
+#include "linear.h"

 #include <glog/logging.h>
 #include <torch/torch.h>
@@ -82,9 +82,8 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(

 torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
   input = input.to(device_);
-  auto bias = (bias_.defined() && rank_ == 0)
-                  ? std::optional<torch::Tensor>(bias_)
-                  : std::nullopt;
+  auto bias =
+      bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;

   torch::Tensor output;

@@ -148,8 +147,8 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {

 // load the weight from the checkpoint
 void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;

   // load and merge the weights on dim 0
   // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
@@ -172,8 +171,8 @@ void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
 void ColumnParallelLinearImpl::load_state_dict(
     const StateDict& state_dict,
     const std::vector<std::string>& prefixes) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;

   // load and merge the weights on dim 0
   // If quant_args_ indicates SmoothQuant, load qweight
@@ -192,7 +191,6 @@ void ColumnParallelLinearImpl::load_state_dict(
         break;
       }
     }
-
     LOAD_FUSED_WEIGHT(qweight, 0);
     LOAD_FUSED_WEIGHT(per_channel_scale, 0);
   } else {
@@ -223,36 +221,32 @@ QKVParallelLinearImpl::QKVParallelLinearImpl(
       parallel_args_(parallel_args),
       options_(options),
       device_(options.device()) {
-  const int32_t QKV_CNT = 3;
   rank_ = parallel_args_.tp_group_->rank();
   world_size_ = parallel_args_.tp_group_->world_size();
   const int64_t out_features_per_partition =
       (num_heads + 2 * num_kv_heads) * head_size;
   // Note: torch.nn.functional.linear performs XA^T + b and as a result
   // we allocate the transpose.
-  qkv_weight_ = register_parameter(
+  weight_ = register_parameter(
       "weight",
       torch::empty({out_features_per_partition, hidden_size}, options),
       /*requires_grad=*/false);
-  qkv_weight_list_.resize(QKV_CNT);

   if (bias) {
-    qkv_bias_ =
+    bias_ =
         register_parameter("bias",
                            torch::empty({out_features_per_partition}, options),
                            /*requires_grad=*/false);
-    qkv_bias_list_.resize(QKV_CNT);
   }
 }

 torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
   input = input.to(device_);
-  auto bias = (qkv_bias_.defined() && rank_ == 0)
-                  ? std::optional<torch::Tensor>(qkv_bias_)
-                  : std::nullopt;
+  auto bias =
+      bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
   xllm::kernel::MatmulParams matmul_params;
   matmul_params.a = input;
-  matmul_params.b = qkv_weight_;
+  matmul_params.b = weight_;
   matmul_params.bias = bias;

   auto output = xllm::kernel::matmul(matmul_params);
@@ -262,46 +256,13 @@ torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
   return output;
 }

-bool QKVParallelLinearImpl::load_qkv_weight(const StateDict& state_dict,
-                                            int32_t index) {
-  if (qkv_weight_list_[index].defined() || state_dict.size() == 0) {
-    return false;
-  }
-  DEFINE_WEIGHT(weight);
-  int64_t out_feature = num_heads_ * head_size_;
-  int64_t rank = rank_;
-  int64_t world_size = world_size_;
-  if (index > 0) {
-    rank = rank_ / num_kv_head_replicas_;
-    world_size = world_size_ / num_kv_head_replicas_;
-    out_feature = num_kv_heads_ * head_size_;
-  }
-  weight_ = torch::empty({out_feature, hidden_size_}, options_);
-  LOAD_SHARDED_WEIGHT(weight, 0);
-  if (weight_is_loaded_) {
-    qkv_weight_list_[index] = weight_.clone();
-  }
-  return weight_is_loaded_;
-}
-
 void QKVParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
   std::vector<std::string> prefixes = {"q_proj.", "k_proj.", "v_proj."};
-  if (!qkv_weight_is_loaded_) {
-    bool all_loaded = true;
-    for (size_t i = 0; i < prefixes.size(); ++i) {
-      all_loaded =
-          all_loaded &&
-          load_qkv_weight(state_dict.get_dict_with_prefix(prefixes[i]), i);
-    }
-    if (all_loaded) {
-      const auto merged_weight = torch::cat(qkv_weight_list_, /*dim=*/0);
-      CHECK_EQ(qkv_weight_.sizes(), merged_weight.sizes())
-          << "weight size mismatch";
-      qkv_weight_.copy_(merged_weight);
-      // release the memory for weight_list
-      qkv_weight_list_.clear();
-      qkv_weight_is_loaded_ = true;
-    }
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
+  LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_);
+  if (bias_.defined()) {
+    LOAD_QKV_WEIGHT(bias, 0, num_kv_head_replicas_);
   }
 }

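The expansion of LOAD_QKV_WEIGHT is not part of this diff. Based on the removed load_qkv_weight helper above, a minimal sketch of the sharding it presumably wraps; the function and the load_sharded_dim0 helper below are hypothetical stand-ins, not the real macro body:

    // Sketch only: mirrors the removed load_qkv_weight logic. The q projection
    // shards with the raw tensor-parallel coordinates; k/v projections are
    // replicated across num_kv_head_replicas ranks, so their shard coordinates
    // are divided down before loading a dim-0 shard from the checkpoint.
    torch::Tensor load_one_projection(const StateDict& state_dict,
                                      int32_t index,  // 0 = q, 1 = k, 2 = v
                                      int64_t rank,
                                      int64_t world_size,
                                      int64_t num_kv_head_replicas) {
      if (index > 0) {
        rank /= num_kv_head_replicas;
        world_size /= num_kv_head_replicas;
      }
      return load_sharded_dim0(state_dict, rank, world_size);
    }
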
@@ -424,8 +385,8 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) {

 // load the weight from the checkpoint
 void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;

   // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
   // normal weight.
@@ -462,7 +423,6 @@ ReplicatedLinearImpl::ReplicatedLinearImpl(
 }

 torch::Tensor ReplicatedLinearImpl::forward(torch::Tensor input) {
-  namespace F = torch::nn::functional;
   auto bias =
       bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
   xllm::kernel::MatmulParams matmul_params;
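
For reference, the forward paths touched above now share one optional-bias matmul shape. A condensed, self-contained sketch of that pattern; this is not a literal excerpt, and forward_common is a hypothetical name, but the MatmulParams fields follow the hunks above:

    // Condensed sketch of the shared pattern: move the input to the module's
    // device, wrap the registered bias (if any) in std::optional, and hand
    // both operands to the xllm matmul kernel.
    torch::Tensor forward_common(torch::Tensor input,
                                 const torch::Tensor& weight,
                                 const torch::Tensor& maybe_bias,
                                 const torch::Device& device) {
      input = input.to(device);
      auto bias = maybe_bias.defined()
                      ? std::optional<torch::Tensor>(maybe_bias)
                      : std::nullopt;
      xllm::kernel::MatmulParams matmul_params;
      matmul_params.a = input;   // [..., hidden_size]
      matmul_params.b = weight;  // [out_features, hidden_size], transposed layout
      matmul_params.bias = bias; // added only when the module registered a bias
      return xllm::kernel::matmul(matmul_params);
    }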