diff --git a/ssh/client_auth.go b/ssh/client_auth.go index f3265655ee..c43489ac4d 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -201,7 +201,19 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand } var methods []string for _, signer := range signers { - ok, err := validateKey(signer.PublicKey(), user, c) + + // In order for the publicKeyCallback to play nice with custom + // AlgorithmSigners, it needs to know which algorithm the key is signed + // with. In most cases, this is just the key type, but in some special + // cases these won't match. For example, it is valid to sign an ssh-rsa + // key with the algorithm "rsa-sha2-256" + pub := signer.PublicKey() + algoname := pub.Type() + if algoNameSigner, ok := signer.(AlgorithmSignerWithAlgoName); ok { + algoname = algoNameSigner.AlgorithmName() + } + + ok, err := validateKey(pub, user, c, algoname) if err != nil { return authFailure, nil, err } @@ -209,13 +221,12 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand continue } - pub := signer.PublicKey() pubKey := pub.Marshal() sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ User: user, Service: serviceSSH, Method: cb.method(), - }, []byte(pub.Type()), pubKey)) + }, []byte(algoname), pubKey)) if err != nil { return authFailure, nil, err } @@ -229,7 +240,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand Service: serviceSSH, Method: cb.method(), HasSig: true, - Algoname: pub.Type(), + Algoname: algoname, PubKey: pubKey, Sig: sig, } @@ -266,26 +277,25 @@ func containsMethod(methods []string, method string) bool { } // validateKey validates the key provided is acceptable to the server. -func validateKey(key PublicKey, user string, c packetConn) (bool, error) { +func validateKey(key PublicKey, user string, c packetConn, algoname string) (bool, error) { pubKey := key.Marshal() msg := publickeyAuthMsg{ User: user, Service: serviceSSH, Method: "publickey", HasSig: false, - Algoname: key.Type(), + Algoname: algoname, PubKey: pubKey, } if err := c.writePacket(Marshal(&msg)); err != nil { return false, err } - return confirmKeyAck(key, c) + return confirmKeyAck(key, c, algoname) } -func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { +func confirmKeyAck(key PublicKey, c packetConn, algoname string) (bool, error) { pubKey := key.Marshal() - algoname := key.Type() for { packet, err := c.readPacket() diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index 63a8e22487..baacb7e24a 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -272,6 +272,36 @@ func TestMethodInvalidAlgorithm(t *testing.T) { } } +func TestMethodAlgorithmSignerWithAlgoName(t *testing.T) { + algSigner, _ := NewAlgorithmSignerFromSigner(testSigners["rsa"], SigAlgoRSASHA2256) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(algSigner), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + // Once SigAlgoRSASHA2256 is implemented in the ssh server, this test + // will pass, and will need to be updated. + err, serverErrors := tryAuthBothSides(t, config, nil) + if err == nil { + t.Fatalf("login succeeded") + } + + found := false + want := "\"rsa-sha2-256\" not accepted" + + var errStrings []string + for _, err := range serverErrors { + found = found || (err != nil && strings.Contains(err.Error(), want)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q", errStrings, want) + } +} + func TestClientHMAC(t *testing.T) { for _, mac := range supportedMACs { config := &ClientConfig{ diff --git a/ssh/keys.go b/ssh/keys.go index 31f26349a0..2bbe464238 100644 --- a/ssh/keys.go +++ b/ssh/keys.go @@ -333,6 +333,51 @@ type AlgorithmSigner interface { SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) } +// An AlgorithmSignerWithAlgoName is a Signer that also supports specifying a +// specific algorithm to use for signing, and can provide the name of the +// algorithm being used to sign. +type AlgorithmSignerWithAlgoName interface { + AlgorithmSigner + + // AlgorithmName returns the name of the algorithm being used by the + // AlgorithmSigner. + AlgorithmName() string +} + +// algorithmSignerWithAlgoName is a struct that implements the +// AlgorithmSignerWithAlgoName interface. Use NewAlgorithmSignerFromSigner +// to instantiate it outside of the ssh library. +type algorithmSignerWithAlgoName struct { + AlgorithmSigner + algorithm string +} + +func (s *algorithmSignerWithAlgoName) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.SignWithAlgorithm(rand, data, s.algorithm) +} + +func (s *algorithmSignerWithAlgoName) AlgorithmName() string { + return s.algorithm +} + +// NewAlgorithmSignerFromSigner takes any ssh.AlgorithmSigner implementation and +// an algorithm name to sign with, and returns an AlgorithmSignerWithAlgoName. +// This can be used in PublicKeysCallback to set custom algorithms during the +// ssh handshake. +func NewAlgorithmSignerFromSigner(signer Signer, algorithm string) (Signer, error) { + algorithmSigner, ok := signer.(AlgorithmSigner) + if !ok { + return nil, errors.New("unable to cast to ssh.AlgorithmSigner") + } + + s := algorithmSignerWithAlgoName{ + algorithmSigner, + algorithm, + } + + return &s, nil +} + type rsaPublicKey rsa.PublicKey func (r *rsaPublicKey) Type() string {