Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions ucm/sparse/gsa/gsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def set_block_hashes(self, token_ids):
hash_value = self.request_hasher(
(parent_block_hash_value, curr_block_token_ids_tuple)
)
self.block_hashes.append(str(hash_value))
parent_block_hash_value = hash_value

if self.rank != 0 and not self.use_mla:
Expand Down Expand Up @@ -421,7 +422,7 @@ def cal_topk(self, intermediate_q, current_layer_id):
dot_product_weights.masked_fill_(self.exclude_mask == 1, float("-inf"))
selected_block_nums = self.topk_len_list[0]
_, top_indices = torch.topk(
dot_product_weights, selected_block_nums, dim=-1, sorted=False
dot_product_weights, selected_block_nums, dim=-1, sorted=True
)
self.topk_caches[current_layer_id][self.cal_topk_id] = top_indices

Expand Down Expand Up @@ -582,7 +583,9 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
if not self.use_mla:
self.gsa_q_cache[current_layer_id][: len(ids)].copy_(query[ids])
else:
self.gsa_q_cache[current_layer_id][self.decode_index].copy_(query)
self.gsa_q_cache[current_layer_id][: len(self.decode_index)].copy_(
query
)
is_cal_kpre = len(self.model_input["calc_block_table"]) > 0
self.gsa_offload_ops.add_copy_req(
is_cal_kpre, current_layer_id, ids, self.gsa_q_cache[current_layer_id]
Expand Down Expand Up @@ -656,7 +659,7 @@ def attention_begin(
else:
attn_metadata.block_tables[
: len(self.prefetch_engine.req_ids_bs)
].copy_(self.model_input["block_tables_mp"][current_layer_id])
] = self.model_input["block_tables_mp"][current_layer_id]
attn_metadata.seq_lens.copy_(
self.model_input["gsa_seq_len"][current_layer_id]
)
Expand All @@ -670,9 +673,7 @@ def attention_begin(
current_layer_id
][self.decode_index]
else:
attn_metadata.decode.block_table[
: len(self.prefetch_engine.req_ids_bs)
].copy_(
attn_metadata.decode.block_table[: len(self.decode_index)] = (
self.model_input["block_tables_mp"][current_layer_id][
self.decode_index
]
Expand Down Expand Up @@ -937,9 +938,9 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
fn = getattr(self.connector, "load")
precision = self.element_size
if self.use_mla:
block_data_size = kv_caches[0].numel() * precision
else:
block_data_size = kv_caches[0][0].numel() * precision
else:
block_data_size = kv_caches[0][0][0].numel() * precision

offsets_k = []
key_src_tensors = []
Expand Down Expand Up @@ -1069,10 +1070,7 @@ def _start_topk_cal(self) -> None:
if req_meta.is_gsa():
cal_topk_id.append(req_meta.index_in_batch)
is_decode.append(True)
one_topk_len = (
gsa_config.compute_topk_len(len(req_meta.blocks))
+ gsa_config.num_prefetch_blocks
)
one_topk_len = gsa_config.compute_topk_len(len(req_meta.blocks))
topk_len_list.append(one_topk_len)
if CUDA_TOPK:
include_masks.append(req_meta.include_mask)
Expand Down
13 changes: 6 additions & 7 deletions ucm/sparse/gsa/prefetch/prefetch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ def model_input_deal(

if self.atb_gsa_enable:
block_table_index = torch.tensor(self.select_bs_index, device="cpu")
self.topk_len = (
gsa_config.compute_topk_len(self._get_max_block_len(gsa_metadata))
+ gsa_config.num_prefetch_blocks
self.topk_len = gsa_config.compute_topk_len(
self._get_max_block_len(gsa_metadata)
)
topk_buf_tmp = self.use_topk_caches[:, block_table_index, :]
topk_buf_tmp = topk_buf_tmp[:, :, : self.topk_len]
Expand Down Expand Up @@ -190,9 +189,8 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp):
)
self.topk_bs = []
for index, req_id in enumerate(self.req_ids_bs):
one_topk_len = (
gsa_config.compute_topk_len(len(gsa_metadata.gsa_stats[req_id].blocks))
+ gsa_config.num_prefetch_blocks
one_topk_len = gsa_config.compute_topk_len(
len(gsa_metadata.gsa_stats[req_id].blocks)
)
self.topk_bs.append(
[
Expand Down Expand Up @@ -536,7 +534,8 @@ def _set_req_stat(
def _get_max_block_len(self, gsa_metadata) -> int:
max_len = 0
for req_id in self.req_ids_bs:
max_len = max(max_len, len(gsa_metadata.gsa_stats[req_id].blocks))
if self.is_gsa_req_id[req_id]:
max_len = max(max_len, len(gsa_metadata.gsa_stats[req_id].blocks))
return max_len

def _no_gsa_input_deal(
Expand Down
17 changes: 6 additions & 11 deletions ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor&
mExtraTopkLen = extraTopkLen;
mLogger.log(LogLevel::INFO,
"GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize "
"%u mBlockSize %u mHeadNum %u\n",
mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum);
"%u mBlockSize %u mHeadNum %u, mIsPythonLoad %d\n",
mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum,
mIsPythonLoad);
}

size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV)
Expand Down Expand Up @@ -343,16 +344,14 @@ void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo,
int blockID = mDocsTables[reqID][layerID][item];
hitBlocks.insert(blockID);
hitBlocksIdx.insert(std::make_pair(item, blockID));
if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { break; }
} else {
missIdxs.push_back(item);
}
}
oss << "------\n";
mLogger.log(LogLevel::DEBUG, oss.str().c_str());
oss.str("");
if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen &&
hitBlocks.size() != (topkLen - mExtraTopkLen)) {
if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen) {
mLogger.log(LogLevel::ERROR,
"|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: "
"%lu, miss size: %lu , topkLen: %d, not equal error\n",
Expand All @@ -368,7 +367,6 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo,
{
int layerID = oneBsInfo.layerID;
std::string reqID = oneBsInfo.reqID;
uint32_t topkLen = oneBsInfo.topkLen;
int bsIndex = oneBsInfo.bsIndex;

int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item<int>();
Expand All @@ -377,8 +375,7 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo,

uint32_t index = 0;
int oneFreeBlockIndex = 0;
while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() &&
hitBlocks.size() < (topkLen - mExtraTopkLen)) {
while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size()) {
int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex];
if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) {
oneFreeBlockIndex += 1;
Expand Down Expand Up @@ -415,9 +412,7 @@ void GSAPrefetchEngineC::RunOneBsPrefetch(std::string reqID, int topkLen, int bs
oneBsInfo.bsIndex = bsIndex;
oneBsInfo.layerID = i;
GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) {
RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
}
if (missIdxs.size() != 0) { RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); }
int successIndex = 0;
for (auto it = hitBlocksIdx.begin(); it != hitBlocksIdx.end(); it++) {
mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second;
Expand Down