Skip to content

Commit 618e332

Browse files
committed
feat: Generate SSH server keys in host agent and use them in guest OS
This change changes the SSH server keys that have been generated for each boot in guest OS to be generated by hostagent for each boot. This allows the hostagent to obtain the public key before booting, so that knownhosts can be used with an ssh connection. The code that uses `ssh.InsecureIgnoreHostKey()` in `x/crypto/ssh` is pointed out in CodeQL as `Use of insecure HostKeyCallback implementation (High)`, so it is an implementation to avoid this. Signed-off-by: Norio Nomura <[email protected]>
1 parent ac1c301 commit 618e332

File tree

10 files changed

+152
-13
lines changed

10 files changed

+152
-13
lines changed

pkg/cidata/cidata.TEMPLATE.d/user-data

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,11 @@ bootcmd:
104104
{{- end }}
105105
{{- end }}
106106
{{- end }}
107+
108+
{{- if .SSHHostKeys }}
109+
ssh_keys:
110+
{{- range $type, $key := .SSHHostKeys }}
111+
{{ $type }}: |
112+
{{ indent 4 $key }}
113+
{{- end }}
114+
{{- end }}

pkg/cidata/cidata.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func setupEnv(instConfigEnv map[string]string, propagateProxyEnv bool, slirpGate
118118
return env, nil
119119
}
120120

121-
func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt bool) (*TemplateArgs, error) {
121+
func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt, hostKeys bool) (*TemplateArgs, error) {
122122
if err := limayaml.Validate(instConfig, false); err != nil {
123123
return nil, err
124124
}
@@ -342,11 +342,19 @@ func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, i
342342
}
343343
}
344344

345+
if hostKeys {
346+
sshHostKeys, err := sshutil.GenerateSSHHostKeys(instDir, args.Hostname)
347+
if err != nil {
348+
return nil, fmt.Errorf("failed to generate SSH host keys: %w", err)
349+
}
350+
args.SSHHostKeys = sshHostKeys
351+
}
352+
345353
return &args, nil
346354
}
347355

348356
func GenerateCloudConfig(ctx context.Context, instDir, name string, instConfig *limatype.LimaYAML) error {
349-
args, err := templateArgs(ctx, false, instDir, name, instConfig, 0, 0, 0, "", false, false, false)
357+
args, err := templateArgs(ctx, false, instDir, name, instConfig, 0, 0, 0, "", false, false, false, false)
350358
if err != nil {
351359
return err
352360
}
@@ -369,7 +377,7 @@ func GenerateCloudConfig(ctx context.Context, instDir, name string, instConfig *
369377
}
370378

371379
func GenerateISO9660(ctx context.Context, drv driver.Driver, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort int, guestAgentBinary, nerdctlArchive string, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt bool) error {
372-
args, err := templateArgs(ctx, true, instDir, name, instConfig, udpDNSLocalPort, tcpDNSLocalPort, vsockPort, virtioPort, noCloudInit, rosettaEnabled, rosettaBinFmt)
380+
args, err := templateArgs(ctx, true, instDir, name, instConfig, udpDNSLocalPort, tcpDNSLocalPort, vsockPort, virtioPort, noCloudInit, rosettaEnabled, rosettaBinFmt, true)
373381
if err != nil {
374382
return err
375383
}
@@ -467,6 +475,13 @@ func GenerateISO9660(ctx context.Context, drv driver.Driver, instDir, name strin
467475
Path: "ssh_authorized_keys",
468476
Reader: strings.NewReader(strings.Join(args.SSHPubKeys, "\n")),
469477
})
478+
for keyType, keyContent := range args.SSHHostKeys {
479+
suffix := strings.Replace(strings.Replace(keyType, "_public", "_key.pub", 1), "_private", "_key", 1)
480+
layout = append(layout, iso9660util.Entry{
481+
Path: "ssh_host_" + suffix,
482+
Reader: strings.NewReader(keyContent),
483+
})
484+
}
470485
return writeCIDataDir(filepath.Join(instDir, filenames.CIDataISODir), layout)
471486
}
472487

pkg/cidata/template.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ type TemplateArgs struct {
115115
Plain bool
116116
TimeZone string
117117
NoCloudInit bool
118+
SSHHostKeys map[string]string // `ssh_keys` field in cloud-init SSH module
118119
}
119120

120121
func ValidateTemplateArgs(args *TemplateArgs) error {

pkg/driver/vz/vsock_forwarder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,5 @@ func (m *virtualMachineWrapper) dialVsock(_ context.Context, port uint32) (conn
7474
func (m *virtualMachineWrapper) checkSSHOverVsockAvailable(ctx context.Context, inst *limatype.Instance) error {
7575
return sshutil.WaitSSHReady(ctx, func(ctx context.Context) (net.Conn, error) {
7676
return m.dialVsock(ctx, uint32(22))
77-
}, "vsock:22", *inst.Config.User.Name, 1)
77+
}, "vsock:22", *inst.Config.User.Name, inst.Name, 1)
7878
}

pkg/driver/wsl2/boot/02-no-cloud-init-setup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ chmod 700 "${LIMA_CIDATA_HOME}"/.ssh/
1717
cp "${LIMA_CIDATA_MNT}"/ssh_authorized_keys "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
1818
chown "${LIMA_CIDATA_UID}:${LIMA_CIDATA_GID}" "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
1919
chmod 600 "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
20+
# copy SSH host keys
21+
mkdir -p /etc/ssh/
22+
cp "${LIMA_CIDATA_MNT}"/ssh_host_* /etc/ssh/
23+
chmod 600 /etc/ssh/ssh_host_*
24+
chmod 644 /etc/ssh/ssh_host_*.pub
2025

2126
# add $LIMA_CIDATA_USER to sudoers
2227
echo "${LIMA_CIDATA_USER} ALL=(ALL) NOPASSWD:ALL" | tee -a /etc/sudoers.d/99_lima_sudoers

pkg/limatype/filenames/filenames.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const (
5050
SerialVirtioSock = "serialv.sock"
5151
SSHSock = "ssh.sock"
5252
SSHConfig = "ssh.config"
53+
SSHKnownHosts = "ssh_known_hosts"
5354
VhostSock = "virtiofsd-%d.sock"
5455
VNCDisplayFile = "vncdisplay"
5556
VNCPasswordFile = "vncpassword"

pkg/networks/usernet/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,9 @@ func (c *Client) WaitOpeningSSHPort(ctx context.Context, inst *limatype.Instance
141141
return err
142142
}
143143
user := *inst.Config.User.Name
144+
instanceName := inst.Name
144145
// -1 avoids both sides timing out simultaneously.
145-
u := fmt.Sprintf("%s/extension/wait-ssh-server?ip=%s&port=22&timeout=%d&user=%s", c.base, ipAddr, timeoutSeconds-1, user)
146+
u := fmt.Sprintf("%s/extension/wait-ssh-server?ip=%s&port=22&timeout=%d&user=%s&instance-name=%s", c.base, ipAddr, timeoutSeconds-1, user, instanceName)
146147
res, err := httpclientutil.Get(ctx, c.client, u)
147148
if err != nil {
148149
return err

pkg/networks/usernet/gvproxy.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,9 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
260260
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", uint16(port16)))
261261

262262
user := r.URL.Query().Get("user")
263-
if user == "" {
264-
msg := "user query parameter is required"
263+
instanceName := r.URL.Query().Get("instance-name")
264+
if user == "" || instanceName == "" {
265+
msg := "user and instanceName query parameters are required"
265266
http.Error(w, msg, http.StatusBadRequest)
266267
return
267268
}
@@ -279,7 +280,7 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
279280
return n.DialContextTCP(ctx, addr)
280281
}
281282
// Wait until the port is available.
282-
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, timeoutSeconds); err != nil {
283+
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, instanceName, timeoutSeconds); err != nil {
283284
http.Error(w, err.Error(), http.StatusRequestTimeout)
284285
} else {
285286
w.WriteHeader(http.StatusOK)

pkg/sshutil/sshutil.go

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@ package sshutil
66
import (
77
"bytes"
88
"context"
9+
"crypto"
10+
"crypto/ecdsa"
11+
"crypto/ed25519"
12+
"crypto/elliptic"
13+
"crypto/rand"
14+
"crypto/rsa"
915
"encoding/base64"
1016
"encoding/binary"
17+
"encoding/pem"
1118
"errors"
1219
"fmt"
20+
"io"
1321
"io/fs"
1422
"net"
1523
"os"
@@ -26,8 +34,10 @@ import (
2634
"github.com/mattn/go-shellwords"
2735
"github.com/sirupsen/logrus"
2836
"golang.org/x/crypto/ssh"
37+
"golang.org/x/crypto/ssh/knownhosts"
2938
"golang.org/x/sys/cpu"
3039

40+
"github.com/lima-vm/lima/v2/pkg/instance/hostname"
3141
"github.com/lima-vm/lima/v2/pkg/ioutilx"
3242
"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
3343
"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
@@ -244,7 +254,6 @@ func CommonOpts(ctx context.Context, sshExe SSHExe, useDotSSH bool) ([]string, e
244254

245255
opts = append(opts,
246256
"StrictHostKeyChecking=no",
247-
"UserKnownHostsFile=/dev/null",
248257
"NoHostAuthenticationForLocalhost=yes",
249258
"PreferredAuthentications=publickey",
250259
"Compression=no",
@@ -345,18 +354,28 @@ func SSHOpts(ctx context.Context, sshExe SSHExe, instDir, username string, useDo
345354
return nil, err
346355
}
347356
controlPath := fmt.Sprintf(`ControlPath="%s"`, controlSock)
357+
userKnownHostsPath := filepath.Join(instDir, filenames.SSHKnownHosts)
358+
userKnownHosts := fmt.Sprintf(`UserKnownHostsFile="%s"`, userKnownHostsPath)
348359
if runtime.GOOS == "windows" {
349360
controlSock, err = ioutilx.WindowsSubsystemPath(ctx, controlSock)
350361
if err != nil {
351362
return nil, err
352363
}
353364
controlPath = fmt.Sprintf(`ControlPath='%s'`, controlSock)
365+
userKnownHostsPath, err = ioutilx.WindowsSubsystemPath(ctx, userKnownHostsPath)
366+
if err != nil {
367+
return nil, err
368+
}
369+
userKnownHosts = fmt.Sprintf(`UserKnownHostsFile='%s'`, userKnownHostsPath)
354370
}
371+
hostKeyAlias := fmt.Sprintf("HostKeyAlias=%s", hostname.FromInstName(filepath.Base(instDir)))
355372
opts = append(opts,
356373
fmt.Sprintf("User=%s", username), // guest and host have the same username, but we should specify the username explicitly (#85)
357374
"ControlMaster=auto",
358375
controlPath,
359376
"ControlPersist=yes",
377+
userKnownHosts,
378+
hostKeyAlias,
360379
)
361380
if forwardAgent {
362381
opts = append(opts, "ForwardAgent=yes")
@@ -514,9 +533,9 @@ func detectAESAcceleration() bool {
514533

515534
// WaitSSHReady waits until the SSH server is ready to accept connections.
516535
// The dialContext function is used to create a connection to the SSH server.
517-
// The addr, user, parameter is used for ssh.ClientConn creation.
536+
// The addr, user, instanceName parameter is used for ssh.ClientConn creation.
518537
// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
519-
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user string, timeoutSeconds int) error {
538+
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user, instanceName string, timeoutSeconds int) error {
520539
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
521540
defer cancel()
522541

@@ -525,11 +544,16 @@ func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Co
525544
if err != nil {
526545
return err
527546
}
547+
// Prepare HostKeyCallback
548+
hostKeyChecker, err := HostKeyCheckerWithKeysInKnownHosts(instanceName)
549+
if err != nil {
550+
return err
551+
}
528552
// Prepare ssh client config
529553
sshConfig := &ssh.ClientConfig{
530554
User: user,
531555
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
532-
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
556+
HostKeyCallback: hostKeyChecker,
533557
Timeout: 10 * time.Second,
534558
}
535559
// Wait until the SSH server is available.
@@ -581,3 +605,86 @@ func userPrivateKeySigner() (ssh.Signer, error) {
581605
}
582606
return signer, nil
583607
}
608+
609+
func HostKeyCheckerWithKeysInKnownHosts(instanceName string) (ssh.HostKeyCallback, error) {
610+
publicKeys, err := PublicKeysFromKnownHosts(instanceName)
611+
if err != nil {
612+
return nil, err
613+
}
614+
return func(_ string, _ net.Addr, key ssh.PublicKey) error {
615+
keyBytes := key.Marshal()
616+
for _, pk := range publicKeys {
617+
if bytes.Equal(keyBytes, pk.Marshal()) {
618+
return nil
619+
}
620+
}
621+
return errors.New("ssh: host key mismatch")
622+
}, nil
623+
}
624+
625+
// PublicKeysFromKnownHosts returns the public keys from the known_hosts file located in the instance directory.
626+
func PublicKeysFromKnownHosts(instanceName string) ([]ssh.PublicKey, error) {
627+
// Load known_hosts from the instance directory
628+
instanceDir, err := dirnames.InstanceDir(instanceName)
629+
if err != nil {
630+
return nil, fmt.Errorf("failed to get instance dir for instance %q: %w", instanceName, err)
631+
}
632+
knownHostsPath := filepath.Join(instanceDir, filenames.SSHKnownHosts)
633+
knownHostsBytes, err := os.ReadFile(knownHostsPath)
634+
if err != nil {
635+
return nil, fmt.Errorf("failed to read known_hosts file at %s: %w", knownHostsPath, err)
636+
}
637+
var publicKeys []ssh.PublicKey
638+
rest := knownHostsBytes
639+
for len(rest) > 0 {
640+
var publicKey ssh.PublicKey
641+
publicKey, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
642+
if err != nil {
643+
return nil, fmt.Errorf("failed to parse public key from known_hosts file %s: %w", knownHostsPath, err)
644+
}
645+
publicKeys = append(publicKeys, publicKey)
646+
}
647+
return publicKeys, nil
648+
}
649+
650+
// GenerateSSHHostKeys generates an Ed25519 host key pair for the SSH server.
651+
// The private key is returned in PEM format, and the public key.
652+
func GenerateSSHHostKeys(instDir, hostname string) (map[string]string, error) {
653+
generators := map[string]func(io.Reader) (crypto.PrivateKey, error){
654+
"ecdsa": func(rand io.Reader) (crypto.PrivateKey, error) {
655+
return ecdsa.GenerateKey(elliptic.P256(), rand)
656+
},
657+
"ed25519": func(rand io.Reader) (crypto.PrivateKey, error) {
658+
_, priv, err := ed25519.GenerateKey(rand)
659+
return priv, err
660+
},
661+
"rsa": func(rand io.Reader) (crypto.PrivateKey, error) {
662+
return rsa.GenerateKey(rand, 3072)
663+
},
664+
}
665+
res := make(map[string]string, len(generators))
666+
var sshKnownHosts []byte
667+
for keyType, generator := range generators {
668+
priv, err := generator(rand.Reader)
669+
if err != nil {
670+
return nil, err
671+
}
672+
privPem, err := ssh.MarshalPrivateKey(priv, hostname)
673+
if err != nil {
674+
return nil, fmt.Errorf("failed to marshal %s private key to PEM format: %w", keyType, err)
675+
}
676+
pub, err := ssh.NewPublicKey(priv.(crypto.Signer).Public())
677+
if err != nil {
678+
return nil, fmt.Errorf("failed to create ssh %s public key: %w", keyType, err)
679+
}
680+
res[keyType+"_private"] = string(pem.EncodeToMemory(privPem))
681+
res[keyType+"_public"] = string(ssh.MarshalAuthorizedKey(pub))
682+
sshKnownHosts = append(sshKnownHosts, knownhosts.Line([]string{hostname}, pub)...)
683+
sshKnownHosts = append(sshKnownHosts, '\n')
684+
}
685+
knownHostsPath := filepath.Join(instDir, filenames.SSHKnownHosts)
686+
if err := os.WriteFile(knownHostsPath, sshKnownHosts, 0o644); err != nil {
687+
return nil, fmt.Errorf("failed to write known_hosts file at %s: %w", knownHostsPath, err)
688+
}
689+
return res, nil
690+
}

pkg/textutil/textutil.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616

1717
// ExecuteTemplate executes a text/template template.
1818
func ExecuteTemplate(tmpl string, args any) ([]byte, error) {
19-
x, err := template.New("").Parse(tmpl)
19+
x, err := template.New("").Funcs(TemplateFuncMap).Parse(tmpl)
2020
if err != nil {
2121
return nil, err
2222
}

0 commit comments

Comments
 (0)