From 27e2def5f0390a9f8d1e059c83783f7d2abd0019 Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Sun, 10 Jul 2022 14:50:26 +0800
Subject: [PATCH] Refactor SSH init code, fix directory creation for
 TrustedUserCAKeys file (#20299)

* Refactor SSH init code, fix directory creation for TrustedUserCAKeys file

* Update modules/ssh/init.go

Co-authored-by: zeripath <art27@cantab.net>

* fix lint copyright

* Update modules/ssh/init.go

Co-authored-by: zeripath <art27@cantab.net>
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
---
 modules/setting/setting.go  | 21 +++-----------
 modules/ssh/init.go         | 55 +++++++++++++++++++++++++++++++++++++
 modules/ssh/ssh_graceful.go |  4 +--
 routers/init.go             | 12 ++------
 4 files changed, 63 insertions(+), 29 deletions(-)
 create mode 100644 modules/ssh/init.go

diff --git a/modules/setting/setting.go b/modules/setting/setting.go
index 1bd9e09a7a..23e3280dc9 100644
--- a/modules/setting/setting.go
+++ b/modules/setting/setting.go
@@ -840,8 +840,9 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
 		SSH.StartBuiltinServer = false
 	}
 
-	trustedUserCaKeys := sec.Key("SSH_TRUSTED_USER_CA_KEYS").Strings(",")
-	for _, caKey := range trustedUserCaKeys {
+	SSH.TrustedUserCAKeysFile = sec.Key("SSH_TRUSTED_USER_CA_KEYS_FILENAME").MustString(filepath.Join(SSH.RootPath, "gitea-trusted-user-ca-keys.pem"))
+
+	for _, caKey := range SSH.TrustedUserCAKeys {
 		pubKey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(caKey))
 		if err != nil {
 			log.Fatal("Failed to parse TrustedUserCaKeys: %s %v", caKey, err)
@@ -849,7 +850,7 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
 
 		SSH.TrustedUserCAKeysParsed = append(SSH.TrustedUserCAKeysParsed, pubKey)
 	}
-	if len(trustedUserCaKeys) > 0 {
+	if len(SSH.TrustedUserCAKeys) > 0 {
 		// Set the default as email,username otherwise we can leave it empty
 		sec.Key("SSH_AUTHORIZED_PRINCIPALS_ALLOW").MustString("username,email")
 	} else {
@@ -858,20 +859,6 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
 
 	SSH.AuthorizedPrincipalsAllow, SSH.AuthorizedPrincipalsEnabled = parseAuthorizedPrincipalsAllow(sec.Key("SSH_AUTHORIZED_PRINCIPALS_ALLOW").Strings(","))
 
-	if !SSH.Disabled && !SSH.StartBuiltinServer {
-		if err = os.MkdirAll(SSH.KeyTestPath, 0o644); err != nil {
-			log.Fatal("Failed to create '%s': %v", SSH.KeyTestPath, err)
-		}
-
-		if len(trustedUserCaKeys) > 0 && SSH.AuthorizedPrincipalsEnabled {
-			fname := sec.Key("SSH_TRUSTED_USER_CA_KEYS_FILENAME").MustString(filepath.Join(SSH.RootPath, "gitea-trusted-user-ca-keys.pem"))
-			if err := os.WriteFile(fname,
-				[]byte(strings.Join(trustedUserCaKeys, "\n")), 0o600); err != nil {
-				log.Fatal("Failed to create '%s': %v", fname, err)
-			}
-		}
-	}
-
 	SSH.MinimumKeySizeCheck = sec.Key("MINIMUM_KEY_SIZE_CHECK").MustBool(SSH.MinimumKeySizeCheck)
 	minimumKeySizes := Cfg.Section("ssh.minimum_key_sizes").Keys()
 	for _, key := range minimumKeySizes {
diff --git a/modules/ssh/init.go b/modules/ssh/init.go
new file mode 100644
index 0000000000..f6332bb18b
--- /dev/null
+++ b/modules/ssh/init.go
@@ -0,0 +1,55 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+
+	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/modules/setting"
+)
+
+func Init() error {
+	if setting.SSH.Disabled {
+		return nil
+	}
+
+	if setting.SSH.StartBuiltinServer {
+		Listen(setting.SSH.ListenHost, setting.SSH.ListenPort, setting.SSH.ServerCiphers, setting.SSH.ServerKeyExchanges, setting.SSH.ServerMACs)
+		log.Info("SSH server started on %s. Cipher list (%v), key exchange algorithms (%v), MACs (%v)",
+			net.JoinHostPort(setting.SSH.ListenHost, strconv.Itoa(setting.SSH.ListenPort)),
+			setting.SSH.ServerCiphers, setting.SSH.ServerKeyExchanges, setting.SSH.ServerMACs,
+		)
+		return nil
+	}
+
+	builtinUnused()
+
+	// FIXME: why 0o644 for a directory .....
+	if err := os.MkdirAll(setting.SSH.KeyTestPath, 0o644); err != nil {
+		return fmt.Errorf("failed to create directory %q for ssh key test: %w", setting.SSH.KeyTestPath, err)
+	}
+
+	if len(setting.SSH.TrustedUserCAKeys) > 0 && setting.SSH.AuthorizedPrincipalsEnabled {
+		caKeysFileName := setting.SSH.TrustedUserCAKeysFile
+		caKeysFileDir := filepath.Dir(caKeysFileName)
+
+		err := os.MkdirAll(caKeysFileDir, 0o700) // SSH.RootPath by default (That is `~/.ssh` in most cases)
+		if err != nil {
+			return fmt.Errorf("failed to create directory %q for ssh trusted ca keys: %w", caKeysFileDir, err)
+		}
+
+		if err := os.WriteFile(caKeysFileName, []byte(strings.Join(setting.SSH.TrustedUserCAKeys, "\n")), 0o600); err != nil {
+			return fmt.Errorf("failed to write ssh trusted ca keys to %q: %w", caKeysFileName, err)
+		}
+	}
+
+	return nil
+}
diff --git a/modules/ssh/ssh_graceful.go b/modules/ssh/ssh_graceful.go
index 98fe17b3bc..9b91baf09e 100644
--- a/modules/ssh/ssh_graceful.go
+++ b/modules/ssh/ssh_graceful.go
@@ -29,7 +29,7 @@ func listen(server *ssh.Server) {
 	log.Info("SSH Listener: %s Closed", server.Addr)
 }
 
-// Unused informs our cleanup routine that we will not be using a ssh port
-func Unused() {
+// builtinUnused informs our cleanup routine that we will not be using a ssh port
+func builtinUnused() {
 	graceful.GetManager().InformCleanup()
 }
diff --git a/routers/init.go b/routers/init.go
index 2898c44607..72ccf3526c 100644
--- a/routers/init.go
+++ b/routers/init.go
@@ -6,10 +6,8 @@ package routers
 
 import (
 	"context"
-	"net"
 	"reflect"
 	"runtime"
-	"strconv"
 
 	"code.gitea.io/gitea/models"
 	asymkey_model "code.gitea.io/gitea/models/asymkey"
@@ -158,14 +156,8 @@ func GlobalInitInstalled(ctx context.Context) {
 
 	mustInitCtx(ctx, syncAppPathForGit)
 
-	if setting.SSH.StartBuiltinServer {
-		ssh.Listen(setting.SSH.ListenHost, setting.SSH.ListenPort, setting.SSH.ServerCiphers, setting.SSH.ServerKeyExchanges, setting.SSH.ServerMACs)
-		log.Info("SSH server started on %s. Cipher list (%v), key exchange algorithms (%v), MACs (%v)",
-			net.JoinHostPort(setting.SSH.ListenHost, strconv.Itoa(setting.SSH.ListenPort)),
-			setting.SSH.ServerCiphers, setting.SSH.ServerKeyExchanges, setting.SSH.ServerMACs)
-	} else {
-		ssh.Unused()
-	}
+	mustInit(ssh.Init)
+
 	auth.Init()
 	svg.Init()
 }