|
"""
    KernelTensorProduct

Tensor product of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
product of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
operator or its alias `⊗` (can be typed by `\\otimes<tab>`).
```jldoctest tensorproduct
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorproduct
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
true

julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
true

julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
true
```
"""
struct KernelTensorProduct{K} <: Kernel
    # Iterable of component kernels (e.g. `Tuple` for concrete typing, or `Vector`).
    kernels::K
end
| 43 | + |
# Varargs convenience constructor: gather all kernels into a tuple so that the
# resulting `KernelTensorProduct` is concretely typed.
KernelTensorProduct(k::Kernel, ks::Kernel...) = KernelTensorProduct(tuple(k, ks...))
| 47 | + |
# Register the struct for functor-style traversal so the wrapped kernels can be
# walked/reconstructed (presumably Functors.jl's `@functor` — used for parameter
# handling; confirm against the package's imports).
@functor KernelTensorProduct
| 49 | + |
# Number of component kernels (equals the number of feature groups consumed).
function Base.length(k::KernelTensorProduct)
    return length(k.kernels)
end
| 51 | + |
"""
    (kernel::KernelTensorProduct)(x, y)

Evaluate the tensor-product kernel at a pair of inputs by multiplying the
component kernels evaluated on the corresponding features,
``\\prod_i k_i(x_i, y_i)``.

Throws a `DimensionMismatch` if `x` or `y` does not provide exactly one feature
per component kernel.
"""
function (kernel::KernelTensorProduct)(x, y)
    if !(length(x) == length(y) == length(kernel))
        # Fix: the original message contained a hard line break mid-sentence,
        # yielding a garbled two-line error; keep it on one line.
        throw(DimensionMismatch("number of kernels and number of features are not consistent"))
    end
    return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end
| 59 | + |
# Ensure the input provides exactly one feature group per component kernel.
function validate_domain(k::KernelTensorProduct, x::AbstractVector)
    if dim(x) != length(k)
        error("number of kernels and groups of features are not consistent")
    end
    return true
end
| 64 | + |
# Split an input collection into per-kernel feature groups: a vector of reals is
# a single group; for matrix-backed inputs, each row (`ColVecs`) or column
# (`RowVecs`) of the underlying storage is one group.
function slices(x::AbstractVector{<:Real})
    return (x,)
end
function slices(x::ColVecs)
    return eachrow(x.X)
end
function slices(x::RowVecs)
    return eachcol(x.X)
end
| 69 | + |
# In-place Gram matrix of a tensor-product kernel: write the first factor
# directly into `K`, then fold in the remaining factors elementwise.
function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    for (i, (ki, xi)) in enumerate(zip(k.kernels, slices(x)))
        if i == 1
            kernelmatrix!(K, ki, xi)
        else
            K .*= kernelmatrix(ki, xi)
        end
    end

    return K
end
| 82 | + |
# In-place cross-kernel matrix between two input collections: first factor is
# written into `K`, subsequent factors are multiplied in elementwise.
function kernelmatrix!(
    K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
    validate_inplace_dims(K, x, y)
    validate_domain(k, x)

    for (i, (ki, xi, yi)) in enumerate(zip(k.kernels, slices(x), slices(y)))
        if i == 1
            kernelmatrix!(K, ki, xi, yi)
        else
            K .*= kernelmatrix(ki, xi, yi)
        end
    end

    return K
end
| 97 | + |
# In-place diagonal of the Gram matrix: same accumulation scheme as
# `kernelmatrix!`, but on the per-kernel diagonals.
function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    for (i, (ki, xi)) in enumerate(zip(k.kernels, slices(x)))
        if i == 1
            kerneldiagmatrix!(K, ki, xi)
        else
            K .*= kerneldiagmatrix(ki, xi)
        end
    end

    return K
end
| 110 | + |
# Gram matrix of a tensor-product kernel: Hadamard product of the per-kernel
# Gram matrices on the corresponding feature groups.
function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    factors = [kernelmatrix(ki, xi) for (ki, xi) in zip(k.kernels, slices(x))]
    return reduce(hadamard, factors)
end
| 115 | + |
# Cross-kernel matrix between two input collections: Hadamard product of the
# per-kernel cross matrices.
function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    slabs = zip(k.kernels, slices(x), slices(y))
    factors = [kernelmatrix(ki, xi, yi) for (ki, xi, yi) in slabs]
    return reduce(hadamard, factors)
end
| 120 | + |
# Diagonal of the Gram matrix: elementwise product of the per-kernel diagonals.
function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    factors = [kerneldiagmatrix(ki, xi) for (ki, xi) in zip(k.kernels, slices(x))]
    return reduce(hadamard, factors)
end
| 125 | + |
# Display delegates to `printshifted` with no initial indentation.
function Base.show(io::IO, kernel::KernelTensorProduct)
    return printshifted(io, kernel, 0)
end
| 127 | + |
# Two tensor products are equal iff they have the same number of component
# kernels and the components are pairwise equal.
function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
    xs, ys = x.kernels, y.kernels
    length(xs) == length(ys) || return false
    return all(t -> first(t) == last(t), zip(xs, ys))
end
| 134 | + |
# Print a header followed by each component kernel on its own line, indented
# one tab level deeper than `shift`; components recurse two levels deeper.
function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
    print(io, "Tensor product of ", length(kernel), " kernels:")
    for component in kernel.kernels
        print(io, "\n", "\t"^(shift + 1))
        printshifted(io, component, shift + 2)
    end
end
0 commit comments