|
# Linear layer: performs a linear transformation of the input array,
# x₁ = softplus.(W) * x₀. The raw weights W are unconstrained; mapping them
# through `softplus` at call time keeps the effective mixing weights positive.
struct LinearLayer{T,MT<:AbstractArray{T}}
    W::MT  # raw weight matrix, size out_dim × in_dim (see the two-arg constructor)
end
@functor LinearLayer  # register W as a trainable parameter for Functors/Flux
| 7 | + |
# Convenience constructor: standard-normal initialised raw weights of size
# `out_dim × in_dim` (they are passed through `softplus` when the layer runs).
function LinearLayer(in_dim, out_dim)
    return LinearLayer(randn(out_dim, in_dim))
end
| 9 | + |
# Apply the layer: constrain the raw weights to be positive with `softplus`,
# then left-multiply the input.
function (lin::LinearLayer)(x)
    positive_weights = softplus.(lin.W)
    return positive_weights * x
end
| 11 | + |
# Display as `LinearLayer(in_dim, out_dim)`, mirroring the constructor arguments.
function Base.show(io::IO, layer::LinearLayer)
    out_dim, in_dim = size(layer.W)
    return print(io, "LinearLayer(", in_dim, ", ", out_dim, ")")
end
| 15 | + |
# Product function: given a 2d array of size M×N, elementwise-multiply every
# `step` neighbouring rows to obtain a new array of size (M÷step)×N.
#
# Implementation note: reshaping to step×(M÷step)×N and reducing with `prod`
# along the first axis replaces the old slice-comprehension-plus-splat
# (`.*([new_x[i, :, :] for i in 1:step]...)`), which copied one slice per row
# group and splatted a runtime-length collection into `.*`.
function product(x, step=2)
    m, n = size(x)
    # validate before reshaping; `error` (ErrorException) kept for compatibility
    m % step == 0 || error("the first dimension of inputs must be multiple of step")
    new_x = reshape(x, step, m ÷ step, n)
    return dropdims(prod(new_x; dims=1); dims=1)
end
| 25 | + |
# Primitive layer: mainly acts as a container holding the basic kernels that
# the neural kernel network combines.
struct Primitive{T}
    kernels::T  # tuple of kernels, captured as-is by the vararg inner constructor
    Primitive(ks...) = new{typeof(ks)}(ks)
end
@functor Primitive  # expose the wrapped kernels' parameters as trainable
| 32 | + |
# Flatten each of the k kernel matrices (size Mk×Nk) into a 1×(Mk*Nk) row and
# stack the rows into a k×(Mk*Nk) 2d array.
# `reduce(vcat, ...)` over the elements replaces the old
# `vcat([reshape(x[i], 1, :) for i in 1:length(x)]...)`, avoiding both the
# index-by-length loop and splatting a runtime-length collection into `vcat`.
_cat_kernel_array(x) = reduce(vcat, [reshape(xᵢ, 1, :) for xᵢ in x])
| 35 | + |
# NOTE: although we implement `ew` & `pw` for Primitive, it is not a subtype of
# Kernel; this is deliberate, as it makes writing NeuralKernelNetwork easier.

# Diagonal (elementwise) evaluations of every primitive kernel on x, stacked row-wise.
ew(p::Primitive, x) = _cat_kernel_array([kernelmatrix_diag(k, x) for k in p.kernels])

# Full pairwise kernel matrices of every primitive kernel on x, flattened and stacked.
pw(p::Primitive, x) = _cat_kernel_array([kernelmatrix(k, x) for k in p.kernels])
| 40 | + |
# Diagonal cross-evaluations k.(x, x′) for every primitive kernel, stacked row-wise.
function ew(p::Primitive, x, x′)
    return _cat_kernel_array([kernelmatrix_diag(k, x, x′) for k in p.kernels])
end

# Pairwise cross-kernel matrices for every primitive kernel, flattened and stacked.
function pw(p::Primitive, x, x′)
    return _cat_kernel_array([kernelmatrix(k, x, x′) for k in p.kernels])
end
| 45 | + |
# Render as `Primitive(<kernel1>, <kernel2>, ...)` with the held kernels
# comma-separated.
function Base.show(io::IO, layer::Primitive)
    return print(io, "Primitive(", join(layer.kernels, ", "), ")")
end
| 51 | + |
"""
    NeuralKernelNetwork(primitives, nn)

Constructs a Neural Kernel Network (NKN) [1].

`primitives` are the base kernels, combined by `nn`.

```julia
k1 = 0.6 * (SEKernel() ∘ ScaleTransform(0.5))
k2 = 0.4 * (Matern32Kernel() ∘ ScaleTransform(0.1))
primitives = Primitive(k1, k2)
nkn = NeuralKernelNetwork(primitives, Chain(LinearLayer(2, 2), product))
```

[1] - Sun, Shengyang, et al. "Differentiable compositional kernel learning for Gaussian
    processes." International Conference on Machine Learning. PMLR, 2018.
"""
struct NeuralKernelNetwork{PT,NNT} <: Kernel
    primitives::PT  # Primitive container holding the base kernels
    nn::NNT         # network (e.g. a Chain) that combines the primitives' outputs
end
@functor NeuralKernelNetwork
| 74 | + |
# Helpers that undo the flattening performed by `_cat_kernel_array`:
# reshape the network's 1d output row back into an n×m kernel matrix, or
# into a plain vector for the diagonal case.
_rebuild_kernel(x, n, m) = reshape(x, n, m)
_rebuild_diag(x) = vec(x)
| 78 | + |
# Scalar evaluation κ(x, y): compute the 1×1 kernel matrix for the singleton
# inputs and unwrap its only element.
function (κ::NeuralKernelNetwork)(x, y)
    return only(kernelmatrix(κ, [x], [y]))
end
| 80 | + |
# Diagonal of the NKN kernel matrix over x: push the primitives' diagonal
# evaluations through the network, then flatten the result to a vector.
function kernelmatrix_diag(nkn::NeuralKernelNetwork, x::AbstractVector)
    features = ew(nkn.primitives, x)
    return _rebuild_diag(nkn.nn(features))
end
| 84 | + |
# Full NKN kernel matrix over x: combine the flattened primitive kernel
# matrices with the network, then reshape to length(x) × length(x).
function kernelmatrix(nkn::NeuralKernelNetwork, x::AbstractVector)
    features = pw(nkn.primitives, x)
    return _rebuild_kernel(nkn.nn(features), length(x), length(x))
end
| 88 | + |
# Diagonal of the cross-kernel between x and x′, via the network.
function kernelmatrix_diag(nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector)
    features = ew(nkn.primitives, x, x′)
    return _rebuild_diag(nkn.nn(features))
end
| 92 | + |
# Cross-kernel matrix between x and x′, reshaped to length(x) × length(x′).
function kernelmatrix(nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector)
    features = pw(nkn.primitives, x, x′)
    return _rebuild_kernel(nkn.nn(features), length(x), length(x′))
end
| 96 | + |
# In-place diagonal: broadcast-assign into the caller-supplied buffer K
# (broadcast assignment returns the destination, i.e. K).
function kernelmatrix_diag!(K::AbstractVector, nkn::NeuralKernelNetwork, x::AbstractVector)
    return K .= kernelmatrix_diag(nkn, x)
end
| 101 | + |
# In-place kernel matrix: broadcast-assign into the caller-supplied buffer K
# (broadcast assignment returns the destination, i.e. K).
function kernelmatrix!(K::AbstractMatrix, nkn::NeuralKernelNetwork, x::AbstractVector)
    return K .= kernelmatrix(nkn, x)
end
| 106 | + |
# In-place cross-diagonal: broadcast-assign into the caller-supplied buffer K
# (broadcast assignment returns the destination, i.e. K).
function kernelmatrix_diag!(
    K::AbstractVector, nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector
)
    return K .= kernelmatrix_diag(nkn, x, x′)
end
| 113 | + |
# In-place cross-kernel matrix: broadcast-assign into the caller-supplied
# buffer K (broadcast assignment returns the destination, i.e. K).
function kernelmatrix!(
    K::AbstractMatrix, nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector
)
    return K .= kernelmatrix(nkn, x, x′)
end
| 120 | + |
# Render as `NeuralKernelNetwork(<primitives>, <nn>)`.
function Base.show(io::IO, kernel::NeuralKernelNetwork)
    return print(io, "NeuralKernelNetwork(", kernel.primitives, ", ", kernel.nn, ")")
end
0 commit comments