diff options
author | Joel Sing <jsing@cvs.openbsd.org> | 2017-03-07 13:22:40 +0000 |
---|---|---|
committer | Joel Sing <jsing@cvs.openbsd.org> | 2017-03-07 13:22:40 +0000 |
commit | 64d12e7429353055cb033905d75792fc898aa051 (patch) | |
tree | ccc8db1bdd4d1715e7f4b8cebe5e450abfb563e4 | |
parent | 9c259ff9d5a0e40aba3cf22230b44885510912ab (diff) |
Add a test that covers a libtls client talking to a Go TLS server with
varying minimum and maximum protocol versions. This gives us protocol
version test coverage against an independent TLS stack.
-rw-r--r-- | regress/lib/libtls/gotls/tls_test.go | 112 |
1 files changed, 107 insertions, 5 deletions
diff --git a/regress/lib/libtls/gotls/tls_test.go b/regress/lib/libtls/gotls/tls_test.go index f48be5dddac..077dd86e82c 100644 --- a/regress/lib/libtls/gotls/tls_test.go +++ b/regress/lib/libtls/gotls/tls_test.go @@ -1,6 +1,7 @@ package tls import ( + "crypto/tls" "encoding/pem" "fmt" "io/ioutil" @@ -24,6 +25,12 @@ var ( certNotAfter = certNotBefore.Add(1000000 * time.Hour) ) +type handshakeError string + +func (he handshakeError) Error() string { + return string(he) +} + // createCAFile writes a PEM encoded version of the certificate out to a // temporary file, for use by libtls. func createCAFile(cert []byte) (string, error) { @@ -42,14 +49,16 @@ func createCAFile(cert []byte) (string, error) { return f.Name(), nil } -func newTestServer() (*httptest.Server, *url.URL, string, error) { - ts := httptest.NewTLSServer( +func newTestServer(tlsCfg *tls.Config) (*httptest.Server, *url.URL, string, error) { + ts := httptest.NewUnstartedServer( http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, httpContent) }, ), ) + ts.TLS = tlsCfg + ts.StartTLS() u, err := url.Parse(ts.URL) if err != nil { @@ -64,8 +73,57 @@ func newTestServer() (*httptest.Server, *url.URL, string, error) { return ts, u, caFile, nil } +func handshakeVersionTest(tlsCfg *tls.Config) (ProtocolVersion, error) { + ts, u, caFile, err := newTestServer(tlsCfg) + if err != nil { + return 0, fmt.Errorf("failed to start test server: %v", err) + } + defer os.Remove(caFile) + defer ts.Close() + + if err := Init(); err != nil { + return 0, err + } + + cfg, err := NewConfig() + if err != nil { + return 0, err + } + defer cfg.Free() + if err := cfg.SetCAFile(caFile); err != nil { + return 0, err + } + if err := cfg.SetCiphers("compat"); err != nil { + return 0, err + } + if err := cfg.SetProtocols(ProtocolsAll); err != nil { + return 0, err + } + + tls, err := NewClient(cfg) + if err != nil { + return 0, err + } + defer tls.Free() + + if err := tls.Connect(u.Host, ""); err != nil { + return 0, err + } + if err := tls.Handshake(); err != nil { + return 0, handshakeError(err.Error()) + } + version, err := tls.ConnVersion() + if err != nil { + return 0, err + } + if err := tls.Close(); err != nil { + return 0, err + } + return version, nil +} + func TestTLSBasic(t *testing.T) { - ts, u, caFile, err := newTestServer() + ts, u, caFile, err := newTestServer(nil) if err != nil { t.Fatalf("Failed to start test server: %v", err) } @@ -120,8 +178,52 @@ func TestTLSBasic(t *testing.T) { } } +func TestTLSVersions(t *testing.T) { + tests := []struct { + minVersion uint16 + maxVersion uint16 + wantVersion ProtocolVersion + wantHandshakeErr bool + }{ + {tls.VersionSSL30, tls.VersionTLS12, ProtocolTLSv12, false}, + {tls.VersionTLS10, tls.VersionTLS12, ProtocolTLSv12, false}, + {tls.VersionTLS11, tls.VersionTLS12, ProtocolTLSv12, false}, + {tls.VersionSSL30, tls.VersionTLS11, ProtocolTLSv11, false}, + {tls.VersionSSL30, tls.VersionTLS10, ProtocolTLSv10, false}, + {tls.VersionSSL30, tls.VersionSSL30, 0, true}, + {tls.VersionTLS10, tls.VersionTLS10, ProtocolTLSv10, false}, + {tls.VersionTLS11, tls.VersionTLS11, ProtocolTLSv11, false}, + {tls.VersionTLS12, tls.VersionTLS12, ProtocolTLSv12, false}, + } + for i, test := range tests { + t.Logf("Testing handshake with protocols %x:%x", test.minVersion, test.maxVersion) + tlsCfg := &tls.Config{ + MinVersion: test.minVersion, + MaxVersion: test.maxVersion, + } + version, err := handshakeVersionTest(tlsCfg) + switch { + case test.wantHandshakeErr && err == nil: + t.Errorf("Test %d - handshake %x:%x succeeded, want handshake error", + i, test.minVersion, test.maxVersion) + case test.wantHandshakeErr && err != nil: + if _, ok := err.(handshakeError); !ok { + t.Errorf("Test %d - handshake %x:%x; got unknown error, want handshake error: %v", + i, test.minVersion, test.maxVersion, err) + } + case !test.wantHandshakeErr && err != nil: + t.Errorf("Test %d - handshake %x:%x failed: %v", i, test.minVersion, test.maxVersion, err) + case !test.wantHandshakeErr && err == nil: + if got, want := version, test.wantVersion; got != want { + t.Errorf("Test %d - handshake %x:%x; got protocol version %v, want %v", + i, test.minVersion, test.maxVersion, got, want) + } + } + } +} + func TestTLSSingleByteReadWrite(t *testing.T) { - ts, u, caFile, err := newTestServer() + ts, u, caFile, err := newTestServer(nil) if err != nil { t.Fatalf("Failed to start test server: %v", err) } @@ -190,7 +292,7 @@ func TestTLSSingleByteReadWrite(t *testing.T) { } func TestTLSInfo(t *testing.T) { - ts, u, caFile, err := newTestServer() + ts, u, caFile, err := newTestServer(nil) if err != nil { t.Fatalf("Failed to start test server: %v", err) } |