Skip to content

Commit

Permalink
[Model][Phi3-Small] Remove scipy from blocksparse_attention (#6343)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Jul 12, 2024
1 parent adf32e0 commit d59eb98
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions vllm/attention/ops/blocksparse_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,35 @@

from functools import lru_cache

import numpy as np
import torch
import triton

try:
from scipy import sparse
except ImportError as err:
raise ImportError("Please install scipy via "
"`pip install scipy` to use "
"BlockSparseAttention in "
"models such as Phi-3.") from err

class csr_matrix:
    """Minimal dense-to-CSR converter that avoids a scipy dependency.

    This replaced scipy.sparse.csr_matrix() previously used. Only the
    conversion is implemented; the result is exposed through the ``shape``,
    ``data``, ``indices`` and ``indptr`` attributes.
    """

    def __init__(self, input_array):
        """Build the CSR representation of a 2-D NumPy array.

        Args:
            input_array: 2-D ``np.ndarray``. Every truthy (non-zero) entry
                is stored, in row-major order.

        Raises:
            ValueError: If ``input_array`` is not a NumPy array or is not
                two-dimensional.
        """
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")
        if input_array.ndim != 2:
            raise ValueError("Input must be a 2-D NumPy array")

        self.shape = input_array.shape
        num_rows = self.shape[0]

        # np.nonzero reports hits in row-major order, which is exactly the
        # order CSR stores them in, so no Python-level loop is needed.
        row_idx, col_idx = np.nonzero(input_array)

        # indptr[i + 1] is the running total of stored entries through row i.
        indptr = np.zeros(num_rows + 1, dtype=np.int64)
        np.cumsum(np.count_nonzero(input_array, axis=1), out=indptr[1:])

        # Fancy indexing keeps the input dtype even when nothing is selected,
        # and the explicit int64 keeps the index arrays integral for an
        # all-zero input (np.array([]) would silently become float64).
        self.data = input_array[row_idx, col_idx]
        self.indices = np.ascontiguousarray(col_idx, dtype=np.int64)
        self.indptr = indptr


def dense_to_crow_col(x: torch.Tensor):
Expand All @@ -26,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert x.dim() in (2, 3)
if x.dim() == 2:
x = x[None]
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols)
Expand Down

0 comments on commit d59eb98

Please sign in to comment.