fix multinomial sampling (#1228)
* fix multinomial sampling kernel dtypes

* fix repetition penalty dtype cast

---------

Co-authored-by: grimoire <[email protected]>
grimoire authored Mar 3, 2024
1 parent f0dabee commit 79ac87b
Showing 3 changed files with 13 additions and 4 deletions.
lmdeploy/pytorch/engine/logits_process.py (2 additions, 0 deletions)
@@ -12,6 +12,7 @@ def _process_temperature(scores: torch.Tensor,
                          temperature: torch.Tensor,
                          inplace: bool = True):
     """process temperature."""
+    temperature = temperature.to(scores.dtype)
     if not inplace:
         scores = scores / temperature[:, None]
     else:
@@ -42,6 +43,7 @@ def _process_repetition_penalty(scores: torch.Tensor,
                                 inplace: bool = True):
     """process repetition penalty."""
     score = torch.gather(scores, 1, input_ids)
+    penalty = penalty.to(score.dtype)
     score = torch.where(score < 0, score * penalty[:, None],
                         score / penalty[:, None])
     if not inplace:
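Both changes in this file cast the per-sequence sampling parameter to the logits dtype before doing the arithmetic. A minimal standalone sketch of the temperature path (the function name and shapes here are illustrative, not the lmdeploy API) shows why the cast matters once the logits are fp16/bf16 while the parameter tensor defaults to fp32:

```python
import torch


def apply_temperature(scores: torch.Tensor,
                      temperature: torch.Tensor,
                      inplace: bool = True) -> torch.Tensor:
    """Scale a (batch, vocab) logits tensor by a per-sequence temperature."""
    # Mirror the patch: cast first so an fp32 temperature neither promotes
    # fp16/bf16 logits nor breaks the in-place division below.
    temperature = temperature.to(scores.dtype)
    if not inplace:
        scores = scores / temperature[:, None]
    else:
        scores /= temperature[:, None]
    return scores


logits = torch.randn(4, 128, dtype=torch.half)
temperature = torch.full((4, ), 0.7)  # float32 by default
apply_temperature(logits, temperature)  # without the cast, the in-place path can raise a dtype error
```

The repetition-penalty change is the same idea applied to `penalty` before the `torch.where` arithmetic.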
lmdeploy/pytorch/kernels/multinomial_sampling.py (3 additions, 3 deletions)
@@ -22,7 +22,7 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
 
     samp = tl.rand(seed, offset)[:, None]
     acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty)
-    output = tl.full((BLOCK, ), -1, dtype=tl.int64)
+    output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty)
 
     for b_idx in range(0, num_tokens, BLOCK_N):
         s_off = b_idx + n_off
@@ -31,8 +31,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
                          s_off[None, :] * stride_st,
                          mask=s_mask,
                          other=0.0)
-        cum_scores = acc[:, None] + tl.cumsum(scores, 1)
-        acc += tl.sum(scores, 1)
+        cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype)
+        acc += tl.sum(scores, 1).to(acc.dtype)
 
         pre_cum_scores = cum_scores - scores
         valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
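The kernel change keeps the output buffer, the running accumulator, and the blockwise cumulative sums in consistent dtypes. A rough PyTorch reference of the same inverse-CDF sampling idea (this is not the Triton kernel; the names, seeding scheme, and the single fp32 upcast are illustrative choices):

```python
import torch


def multinomial_sample_ref(scores: torch.Tensor, indices: torch.Tensor,
                           seed: int = 0) -> torch.Tensor:
    """Pick one column per row with probability proportional to `scores`,
    then return the corresponding entry of `indices`."""
    # The Triton kernel accumulates in the score dtype and casts the block
    # cumsum/sum back to the accumulator dtype; this reference simply
    # upcasts once to fp32 so fp16/bf16 inputs behave the same way.
    cum = torch.cumsum(scores.to(torch.float32), dim=1)
    gen = torch.Generator(device=scores.device).manual_seed(seed)
    samp = torch.rand(scores.size(0), 1, generator=gen, device=scores.device)
    samp = samp * cum[:, -1:]  # scale the uniform sample into the CDF range
    # first position whose cumulative mass reaches the sample
    pos = torch.searchsorted(cum, samp).clamp_max(scores.size(1) - 1)
    return torch.gather(indices, 1, pos).squeeze(1)


scores = torch.zeros(2, 8, dtype=torch.half)
scores[torch.arange(2), torch.tensor([4, 2])] = 1  # one-hot rows, as in the test
indices = torch.arange(8).repeat(2, 1)
print(multinomial_sample_ref(scores, indices))  # tensor([4, 2])
```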
tests/pytorch/kernel/test_multinomial_sampling.py (8 additions, 1 deletion)
@@ -19,10 +19,15 @@ def batch_size(self, select_ids):
         yield len(select_ids)
 
     @pytest.fixture
-    def scores(self, num_tokens, batch_size, select_ids):
+    def dtype(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def scores(self, num_tokens, batch_size, select_ids, dtype):
         ret = torch.zeros(batch_size, num_tokens).cuda()
         batch_ids = torch.arange(batch_size).cuda()
         ret[batch_ids, select_ids] = 1
+        ret = ret.to(dtype)
         yield ret
 
     @pytest.fixture
@@ -45,6 +50,8 @@ def gt(self, batch_size, select_ids, indices):
         batch_ids = torch.arange(batch_size).cuda()
         yield indices[batch_ids, select_ids]
 
+    @pytest.mark.parametrize('dtype',
+                             [torch.float32, torch.half, torch.bfloat16])
     @pytest.mark.parametrize(['num_tokens', 'select_ids'], [
         (8, (4, 2) * 30),
         (200, (50, 150)),
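The test now sweeps the kernel over fp32, fp16, and bf16 scores via the new `dtype` fixture. A simplified, self-contained sketch of that fixture/parametrize pattern (the class and test names here are made up, and `indirect=True` is used to make the routing through `request.param` explicit):

```python
import pytest
import torch


class TestDtypeParametrization:

    @pytest.fixture
    def dtype(self, request):
        # receives each parametrized value through request.param
        yield request.param

    @pytest.fixture
    def scores(self, dtype):
        yield torch.zeros(4, 8, dtype=dtype)

    @pytest.mark.parametrize('dtype',
                             [torch.float32, torch.half, torch.bfloat16],
                             indirect=True)
    def test_scores_dtype(self, scores, dtype):
        # the dependent fixture sees the same dtype as the test
        assert scores.dtype == dtype
```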
