Skip to content

Commit 31a37b5

Browse files
committed
restrict which commands can have arguments rewritten
1 parent ea2873b commit 31a37b5

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

main.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,15 @@ func main() {
1818
return
1919
}
2020

21-
// Change any ssh or git url arguments to https
22-
for i, arg := range os.Args {
23-
os.Args[i] = Scrub(arg)
21+
if IsRewriteAllowed(os.Args[1:]) {
22+
// Change any ssh or git url arguments to https
23+
for i, arg := range os.Args[1:] {
24+
os.Args[i] = Scrub(arg)
25+
}
2426
}
2527

2628
// Run the scrubbed git command
27-
var cmd *exec.Cmd
28-
if len(os.Args) == 1 {
29-
cmd = exec.Command(os.Args[0])
30-
} else {
31-
cmd = exec.Command(os.Args[0], os.Args[1:]...)
32-
}
29+
cmd := exec.Command(os.Args[0], os.Args[1:]...)
3330
cmd.Stderr = os.Stderr
3431
cmd.Stdout = os.Stdout
3532
err := cmd.Run()
@@ -39,6 +36,26 @@ func main() {
3936
}
4037
}
4138

39+
var allowedCommands = []string{"clone", "fetch"}
40+
41+
// IsRewriteAllowed returns true if it is safe to rewrite arguments. Some commands
42+
// such as config would break if rewritten, like when using insteadOf.
43+
func IsRewriteAllowed(args []string) bool {
44+
for _, arg := range args {
45+
if strings.HasPrefix(arg, "-") {
46+
continue
47+
}
48+
for _, allowed := range allowedCommands {
49+
if arg == allowed {
50+
return true
51+
}
52+
}
53+
return false
54+
}
55+
return false
56+
}
57+
58+
// FindGit finds the second git executable on the path, the first being this one.
4259
func FindGit(envPath string) string {
4360
paths := strings.Split(envPath, string(os.PathListSeparator))
4461
var shimPath string
@@ -68,6 +85,7 @@ func FindGit(envPath string) string {
6885

6986
var scpUrl = regexp.MustCompile(`^(?P<user>\S+?)@(?P<host>[a-zA-Z\d-]+(\.[a-zA-Z\d-]+)+\.?):(?P<path>.*?/.*?)$`)
7087

88+
// Scrub rewrites arguments that look like URLs to have the HTTPS protocol.
7189
func Scrub(argument string) string {
7290
u, err := url.ParseRequestURI(argument)
7391
if err == nil && u.Scheme != "" {

main_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,38 @@ import (
66
"testing"
77
)
88

9+
func TestIsRewriteAllowed(t *testing.T) {
10+
var cases = []struct {
11+
input []string
12+
expected bool
13+
}{
14+
{
15+
input: []string{"clone", "[email protected]:org/repo"},
16+
expected: true,
17+
},
18+
{
19+
input: []string{"fetch", ""},
20+
expected: true,
21+
},
22+
{
23+
input: []string{"--work-tree=/work", "clone"},
24+
expected: true,
25+
},
26+
{
27+
input: []string{"config", "--global", "url.\"https://github.com/\".insteadOf", "[email protected]:"},
28+
expected: false,
29+
},
30+
}
31+
32+
for _, test := range cases {
33+
t.Run(fmt.Sprintln(test.input), func(t *testing.T) {
34+
if v := IsRewriteAllowed(test.input); v != test.expected {
35+
t.Errorf("Input: %v\tExpected: %v\tGot: %v\n", test.input, test.expected, v)
36+
}
37+
})
38+
}
39+
}
40+
941
func TestScrub(t *testing.T) {
1042
var cases = []struct {
1143
input string

0 commit comments

Comments
 (0)