@@ -137,11 +137,39 @@ function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
137137 s, t, eweight = coo
138138 eweight = isnothing (eweight) ? fill! (similar (s, T), 1 ) : eweight
139139 num_nodes = isnothing (num_nodes) ? max (maximum (s), maximum (t)) : num_nodes
140- A = sparse (s, t, eweight, num_nodes, num_nodes)
140+ A = _sparse (s, t, eweight, num_nodes, num_nodes)
141141 num_edges = length (s)
142142 return A, num_nodes, num_edges
143143end
144144
145+ _sparse (s, t, eweight, n, m) = sparse (s, t, eweight, n, m)
146+
147+ function _sparse (I:: CuVector , J:: CuVector , V:: CuVector , m, n)
148+ spcoo = CuSparseMatrixCOO {Float32, Int32} (Int32 .(I), Int32 .(J), Float32 .(V), (m, n))
149+ return CuSparseMatrixCSR (spcoo)
150+ end
151+
152+ # function _sparse(I::CuVector, J::CuVector, V::CuVector, m, n; fmt=:csr)
153+ # # Tv = Int32
154+ # spcoo = CuSparseMatrixCOO{Float32, Int32}(Int32.(I), Int32.(J), Float32.(V), (m, n))
155+ # if fmt == :csc
156+ # return CuSparseMatrixCSC(spcoo)
157+ # elseif fmt == :csr
158+ # return CuSparseMatrixCSR(spcoo)
159+ # elseif fmt == :coo
160+ # return spcoo
161+ # else
162+ # error("Format :$fmt not available, use :csc, :csr, or :coo.")
163+ # end
164+ # end
165+
166+
167+ # Workaround for https://github.com/JuliaGPU/CUDA.jl/issues/1113#issuecomment-955759875
168+ function Base.:* (A:: CuMatrix , B:: CuSparseMatrixCSR )
169+ @assert size (A, 2 ) == size (B, 1 )
170+ return CuMatrix ((B' * A' )' )
171+ end
172+
145173
146174@non_differentiable to_coo (x... )
147175@non_differentiable to_dense (x... )
0 commit comments