diff --git a/src/matmul.jl b/src/matmul.jl index d3eabfda..8593aecb 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -317,7 +317,7 @@ end BlasFlag.SYRK elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') BlasFlag.HERK - else isntc + else BlasFlag.GEMM end else @@ -499,7 +499,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β) return false end -# THE one big BLAS dispatch. This is split into two methods to improve latency +# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} mA, nA = lapack_size(tA, A) @@ -511,6 +511,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) return C end + +function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal} + gemm_wrapper!(C, tA, tB, A, B, α, β) +end + Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK}) if A === B tA_uc = uppercase(tA) # potentially strip a WrapperChar @@ -657,14 +663,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) -function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{true}) where {T<:BlasReal} - gemm_wrapper!(C, tA, tB, A, B, α, β) -end -Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal} - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) -end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} =