@@ -19,6 +19,7 @@ kernelmatrix
1919
2020"""
2121 kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X; obsdim::Int = 2)
22+ kerneldiagmatrix!(K::AbstractVector, κ::Kernel, X, Y; obsdim::Int = 2)
2223
2324In place version of [`kerneldiagmatrix`](@ref)
2425"""
@@ -30,6 +31,11 @@ kerneldiagmatrix!
3031Calculate the diagonal matrix of `X` with respect to kernel `κ`
3132`obsdim = 1` means the matrix `X` has size #samples x #dimension
3233`obsdim = 2` means the matrix `X` has size #dimension x #samples
34+
35+ kerneldiagmatrix(κ::Kernel, X, Y; obsdim::Int = 2)
36+
37+ Calculate the diagonal of `kernelmatrix(κ, X, Y; obsdim)` efficiently. Requires that `X` and
38+ `Y` are the same length.
3339"""
3440kerneldiagmatrix
3541
@@ -59,8 +65,16 @@ function kerneldiagmatrix!(K::AbstractVector, κ::Kernel, x::AbstractVector)
5965 return map! (x -> κ (x, x), K, x)
6066end
6167
68+ function kerneldiagmatrix! (
69+ K:: AbstractVector , κ:: Kernel , x:: AbstractVector , y:: AbstractVector ,
70+ )
71+ return map! (κ, x, y)
72+ end
73+
6274kerneldiagmatrix (κ:: Kernel , x:: AbstractVector ) = map (x -> κ (x, x), x)
6375
76+ kerneldiagmatrix (κ:: Kernel , x:: AbstractVector , y:: AbstractVector ) = map (κ, x, y)
77+
6478
6579
6680#
99113const defaultobs = 2
100114
101115function kernelmatrix! (
102- K:: AbstractMatrix , κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs
116+ K:: AbstractMatrix , κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs,
103117)
104118 return kernelmatrix! (K, κ, vec_of_vecs (X; obsdim= obsdim))
105119end
106120
107121function kernelmatrix! (
108122 K:: AbstractMatrix , κ:: Kernel , X:: AbstractMatrix , Y:: AbstractMatrix ;
109- obsdim:: Int = defaultobs
123+ obsdim:: Int = defaultobs,
110124)
111- x = vec_of_vecs (X; obsdim= obsdim)
112- y = vec_of_vecs (Y; obsdim= obsdim)
113- return kernelmatrix! (K, κ, x, y)
125+ return kernelmatrix! (K, κ, vec_of_vecs (X; obsdim= obsdim), vec_of_vecs (Y; obsdim= obsdim))
114126end
115127
116- function kernelmatrix (κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
128+ function kernelmatrix (κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
117129 return kernelmatrix (κ, vec_of_vecs (X; obsdim= obsdim))
118130end
119131
120132function kernelmatrix (κ:: Kernel , X:: AbstractMatrix , Y:: AbstractMatrix ; obsdim= defaultobs)
121- x = vec_of_vecs (X; obsdim= obsdim)
122- y = vec_of_vecs (Y; obsdim= obsdim)
123- return kernelmatrix (κ, x, y)
133+ return kernelmatrix (κ, vec_of_vecs (X; obsdim= obsdim), vec_of_vecs (Y; obsdim= obsdim))
124134end
125135
126136function kerneldiagmatrix! (
127- K:: AbstractVector , κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs
137+ K:: AbstractVector , κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs
128138)
129139 return kerneldiagmatrix! (K, κ, vec_of_vecs (X; obsdim= obsdim))
130140end
131141
132- function kerneldiagmatrix (κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
142+ function kerneldiagmatrix! (
143+ K:: AbstractVector , κ:: Kernel , X:: AbstractMatrix , Y:: AbstractMatrix ;
144+ obsdim:: Int = defaultobs,
145+ )
146+ return kerneldiagmatrix! (
147+ K, κ, vec_of_vecs (X; obsdim= obsdim), vec_of_vecs (Y; obsdim= obsdim),
148+ )
149+ end
150+
151+ function kerneldiagmatrix (κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
133152 return kerneldiagmatrix (κ, vec_of_vecs (X; obsdim= obsdim))
134153end
154+
155+ function kerneldiagmatrix (
156+ κ:: Kernel , X:: AbstractMatrix , Y:: AbstractMatrix ; obsdim:: Int = defaultobs,
157+ )
158+ return kerneldiagmatrix (κ, vec_of_vecs (X; obsdim= obsdim), vec_of_vecs (Y; obsdim= obsdim))
159+ end
0 commit comments