KernelFunctions.jl provides a flexible and complete framework for kernel functions, including transformations applied to the input data before the kernel is evaluated.
The aim is to make the API as model-agnostic as possible while keeping it user-friendly.
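The example below builds four kernel matrices on 100 one-dimensional points and plots them as heatmaps:

```julia
using KernelFunctions
using Plots

# 100 points in [-3, 3], stored as a 100×1 matrix (one observation per row)
X = reshape(collect(range(-3.0, 3.0, length=100)), :, 1)

# Apply a simple scaling to the data
k₁ = SqExponentialKernel(1.0)
K₁ = kernelmatrix(k₁, X, obsdim=1)

# Apply a function transformation to the data
k₂ = MaternKernel(FunctionTransform(x -> sin.(x)))
K₂ = kernelmatrix(k₂, X, obsdim=1)

# Premultiply the data by a matrix
k₃ = PolynomialKernel(LowRankTransform(randn(4, 1)), 2.0, 0.0)
K₃ = kernelmatrix(k₃, X, obsdim=1)

# Combine kernels with products and weighted sums
k₄ = 0.5 * SqExponentialKernel() * LinearKernel(0.5) + 0.4 * k₂
K₄ = kernelmatrix(k₄, X, obsdim=1)

# Plot the four kernel matrices as heatmaps in a 2×2 grid
plot(heatmap.([K₁, K₂, K₃, K₄], yflip=true, colorbar=false)...,
     layout=(2, 2), title=["K₁" "K₂" "K₃" "K₄"])
```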
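Conceptually, a kernel with an input transform `t` evaluates the base kernel on transformed inputs, k(t(x), t(y)). The library-free sketch below illustrates this idea without using KernelFunctions.jl. `sqexp`, `poly2` and `transformed_kernelmatrix` are hypothetical helpers, and the squared-exponential form exp(-‖x - y‖²/2) is an assumed convention that may differ from the package's own parameterization.

```julia
using LinearAlgebra

# Hypothetical base kernels (assumed conventions, not KernelFunctions.jl's API):
# squared exponential exp(-‖x - y‖² / 2) and a degree-2 polynomial ⟨x, y⟩².
sqexp(x, y) = exp(-sum(abs2, x .- y) / 2)
poly2(x, y) = dot(x, y)^2

# Hypothetical helper: apply the transform `t` to every observation (one per
# row of X, as with obsdim=1), then evaluate `kernel` on all pairs.
function transformed_kernelmatrix(kernel, t, X::AbstractMatrix)
    Z = [t(X[i, :]) for i in axes(X, 1)]   # transformed observations
    n = length(Z)
    K = Matrix{Float64}(undef, n, n)
    for j in 1:n, i in 1:n
        K[i, j] = kernel(Z[i], Z[j])
    end
    return K
end

X = reshape(collect(range(-3.0, 3.0, length=100)), :, 1)

# Function transform: evaluate the kernel on sin.(x) instead of x.
K_sin = transformed_kernelmatrix(sqexp, x -> sin.(x), X)

# Linear (low-rank) transform: premultiply each 1-D input by a 4×1 matrix A.
A = randn(4, 1)
K_lin = transformed_kernelmatrix(poly2, x -> A * x, X)
```

Because the transform and the base kernel are separate arguments, any transform can be paired with any kernel. This is the kind of composition the package's transforms are meant to provide.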
Project goals:

- Ensure automatic differentiation (AD) compatibility (already the case for Zygote and ForwardDiff)
- Toeplitz matrix compatibility
- BLAS backend
This package is directly inspired by the MLKernels package.

If you notice a problem, or would like to contribute by adding more kernel functions or features, please submit an issue.