Skip to content

Commit

Permalink
PR feedback + remove CloseIdleConnections
Browse files Browse the repository at this point in the history
  • Loading branch information
harold-s committed Sep 13, 2024
1 parent f4ef674 commit 85fea80
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 96 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,4 @@ See [Development.md](Development.md)
- Evan Broder
- Marc-André Tremblay
- Ryan Koppenhaver
- Harold Simpson
85 changes: 45 additions & 40 deletions pkg/smokescreen/acl/v1/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ type Rule struct {
}

type MitmDomain struct {
MitmConfig
Domain string
AddHeaders map[string]string
DetailedHttpLogs bool
DetailedHttpLogsFullHeaders []string
Domain string
}

type MitmConfig struct {
Expand Down Expand Up @@ -142,8 +144,12 @@ func (acl *ACL) Decide(service, host, connectProxyHost string) (Decision, error)
// if the host matches any of the rule's allowed domains with MITM config, allow
for _, dg := range rule.DomainMitmGlobs {
if HostMatchesGlob(host, dg.Domain) {
d.Result, d.Reason = Allow, "host matched allowed domain in rule"
d.MitmConfig = (*MitmConfig)(&dg.MitmConfig)
d.Result, d.Reason = Allow, "host matched allowed domain in MITM rule"
d.MitmConfig = &MitmConfig{
AddHeaders: dg.AddHeaders,
DetailedHttpLogs: dg.DetailedHttpLogs,
DetailedHttpLogsFullHeaders: dg.DetailedHttpLogsFullHeaders,
}
return d, nil
}
}
Expand Down Expand Up @@ -215,59 +221,58 @@ func (acl *ACL) Validate() error {
}

func (acl *ACL) ValidateRuleDomainsGlobs(svc string, r Rule) error {
err := acl.ValidateDomainGlobs(svc, r.DomainGlobs)
if err != nil {
return err
}
mitmDomainGlobs := make([]string, len(r.DomainMitmGlobs))
for i, d := range r.DomainMitmGlobs {
mitmDomainGlobs[i] = d.Domain
var err error
for _, d := range r.DomainGlobs {
err = acl.ValidateDomainGlob(svc, d)
if err != nil {
return err
}
}
err = acl.ValidateDomainGlobs(svc, mitmDomainGlobs)
if err != nil {
return err
for _, d := range r.DomainMitmGlobs {
err = acl.ValidateDomainGlob(svc, d.Domain)
if err != nil {
return err
}
}
return nil
}

// ValidateDomainGlobs takes a slice of domain globs and verifies they conform to smokescreen's
// ValidateDomainGlob takes a domain glob and verifies they conform to smokescreen's
// domain glob policy.
//
// Wildcards are valid only at the beginning of a domain glob, and only a single wildcard per glob
// pattern is allowed. Globs must include text after a wildcard.
//
// Domains must use their normalized form (e.g., Punycode)
func (acl *ACL) ValidateDomainGlobs(svc string, globs []string) error {
for _, glob := range globs {
if glob == "" {
return fmt.Errorf("glob cannot be empty")
}
func (*ACL) ValidateDomainGlob(svc string, glob string) error {
if glob == "" {
return fmt.Errorf("glob cannot be empty")
}

if glob == "*" || glob == "*." {
return fmt.Errorf("%v: %v: domain glob must not match everything", svc, glob)
}
if glob == "*" || glob == "*." {
return fmt.Errorf("%v: %v: domain glob must not match everything", svc, glob)
}

if !strings.HasPrefix(glob, "*.") && strings.HasPrefix(glob, "*") {
return fmt.Errorf("%v: %v: domain glob must represent a full prefix (sub)domain", svc, glob)
}
if !strings.HasPrefix(glob, "*.") && strings.HasPrefix(glob, "*") {
return fmt.Errorf("%v: %v: domain glob must represent a full prefix (sub)domain", svc, glob)
}

domainToCheck := strings.TrimPrefix(glob, "*")
if strings.Contains(domainToCheck, "*") {
return fmt.Errorf("%v: %v: domain globs are only supported as prefix", svc, glob)
}
domainToCheck := strings.TrimPrefix(glob, "*")
if strings.Contains(domainToCheck, "*") {
return fmt.Errorf("%v: %v: domain globs are only supported as prefix", svc, glob)
}

normalizedDomain, err := hostport.NormalizeHost(domainToCheck, false)
normalizedDomain, err := hostport.NormalizeHost(domainToCheck, false)

if err != nil {
return fmt.Errorf("%v: %v: incorrect ACL entry: %v", svc, glob, err)
} else if normalizedDomain != domainToCheck {
// There was no error but the config contains a non-normalized form
if strings.HasPrefix(glob, "*.") {
// (Re-add) wildcard if one was provided (for the error message)
normalizedDomain = "*." + normalizedDomain
}
return fmt.Errorf("%v: %v: incorrect ACL entry; use %q", svc, glob, normalizedDomain)
if err != nil {
return fmt.Errorf("%v: %v: incorrect ACL entry: %v", svc, glob, err)
// There was no error but the config contains a non-normalized form
} else if normalizedDomain != domainToCheck {
if strings.HasPrefix(glob, "*.") {
// (Re-add) wildcard if one was provided (for the error message)
normalizedDomain = "*." + normalizedDomain
}
return fmt.Errorf("%v: %v: incorrect ACL entry; use %q", svc, glob, normalizedDomain)
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/smokescreen/acl/v1/acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func TestMitmComfig(t *testing.T) {
d, err := acl.Decide(mitmService, "example-mitm.com", "")
a.NoError(err)
a.Equal(Allow, d.Result)
a.Equal("host matched allowed domain in rule", d.Reason)
a.Equal("host matched allowed domain in MITM rule", d.Reason)

a.NotNil(d.MitmConfig)
a.Equal(true, d.MitmConfig.DetailedHttpLogs)
Expand Down
27 changes: 11 additions & 16 deletions pkg/smokescreen/acl/v1/yaml_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,7 @@ func (cfg *YAMLConfig) Load() (*ACL, error) {
var allowedHostsMitm []MitmDomain

for _, w := range v.AllowedHostsMitm {
mitmDomain := MitmDomain{
MitmConfig: MitmConfig{
AddHeaders: w.AddHeaders,
DetailedHttpLogs: w.DetailedHttpLogs,
DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders,
},
Domain: w.Domain,
}
mitmDomain := NewMITMDomain(w)
allowedHostsMitm = append(allowedHostsMitm, mitmDomain)
}

Expand All @@ -123,14 +116,7 @@ func (cfg *YAMLConfig) Load() (*ACL, error) {
var allowedHostsMitm []MitmDomain

for _, w := range cfg.Default.AllowedHostsMitm {
mitmDomain := MitmDomain{
MitmConfig: MitmConfig{
AddHeaders: w.AddHeaders,
DetailedHttpLogs: w.DetailedHttpLogs,
DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders,
},
Domain: w.Domain,
}
mitmDomain := NewMITMDomain(w)
allowedHostsMitm = append(allowedHostsMitm, mitmDomain)
}

Expand All @@ -155,3 +141,12 @@ func (cfg *YAMLConfig) Load() (*ACL, error) {

return &acl, nil
}

func NewMITMDomain(w YAMLMitmRule) MitmDomain {
return MitmDomain{
AddHeaders: w.AddHeaders,
DetailedHttpLogs: w.DetailedHttpLogs,
DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders,
Domain: w.Domain,
}
}
8 changes: 6 additions & 2 deletions pkg/smokescreen/config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,14 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
}
mitmCa, err := tls.LoadX509KeyPair(yc.MitmCaCertFile, yc.MitmCaKeyFile)
if err != nil {
return fmt.Errorf("could not load mitmCa: %v", err)
return fmt.Errorf("mitm_ca_key_file error tls.LoadX509KeyPair: %w", err)
}
// set the leaf certificat to reduce per-handshake processing
if len(mitmCa.Certificate) == 0 {
return errors.New("mitm_ca_key_file error: mitm_ca_key_file contains no certificates")
}
if mitmCa.Leaf, err = x509.ParseCertificate(mitmCa.Certificate[0]); err != nil {
return fmt.Errorf("could not populate x509 Leaf value: %v", err)
return fmt.Errorf("could not populate x509 Leaf value: %w", err)
}
c.MitmTLSConfig = goproxy.TLSConfigFromCA(&mitmCa)
}
Expand Down
64 changes: 29 additions & 35 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,7 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout)
}
connTime := time.Since(start)

fields := logrus.Fields{
LogFieldConnEstablishMS: connTime.Milliseconds(),
}
sctx.logger = sctx.logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime))

if sctx.cfg.TimeConnect {
sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.requestedHost}, 1)
Expand All @@ -307,6 +304,22 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, true)

// Only wrap CONNECT conns with an InstrumentedConn. Connections used for traditional HTTP proxy
// requests are pooled and reused by net.Transport.
if sctx.proxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType)
pctx.ConnErrorHandler = ic.Error
conn = ic
} else {
conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout)
}

return conn, nil
}
func dialContextLoggerFields(pctx *goproxy.ProxyCtx, sctx *SmokescreenContext, conn net.Conn, connTime time.Duration) logrus.Fields {
fields := logrus.Fields{
LogFieldConnEstablishMS: connTime.Milliseconds(),
}
if conn != nil {
if addr := conn.LocalAddr(); addr != nil {
fields[LogFieldOutLocalAddr] = addr.String()
Expand All @@ -316,30 +329,14 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
fields[LogFieldOutRemoteAddr] = addr.String()
}
}
sctx.logger = sctx.logger.WithFields(fields)

// Only wrap CONNECT conns and MITM http conns with an InstrumentedConn. Connections used for traditional HTTP proxy
// requests are pooled and reused by net.Transport.
if sctx.proxyType == connectProxy || pctx.ConnectAction == goproxy.ConnectMitm {
// If we have a MITM and option is enabled, we can add detailed Request log fields
if pctx.ConnectAction == goproxy.ConnectMitm && sctx.Decision.MitmConfig != nil && sctx.Decision.MitmConfig.DetailedHttpLogs {
fields := logrus.Fields{
LogMitmReqUrl: pctx.Req.URL.String(),
LogMitmReqMethod: pctx.Req.Method,
LogMitmReqHeaders: redactHeaders(pctx.Req.Header, sctx.Decision.MitmConfig.DetailedHttpLogsFullHeaders),
}

sctx.logger = sctx.logger.WithFields(fields)

}
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType)
pctx.ConnErrorHandler = ic.Error
conn = ic
} else {
conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout)
// If we have a MITM and option is enabled, we can add detailed Request log fields
if pctx.ConnectAction == goproxy.ConnectMitm && sctx.Decision.MitmConfig != nil && sctx.Decision.MitmConfig.DetailedHttpLogs {
fields[LogMitmReqUrl] = pctx.Req.URL.String()
fields[LogMitmReqMethod] = pctx.Req.Method
fields[LogMitmReqHeaders] = redactHeaders(pctx.Req.Header, sctx.Decision.MitmConfig.DetailedHttpLogsFullHeaders)
}

return conn, nil
return fields
}

// HTTPErrorHandler allows returning a custom error response when smokescreen
Expand Down Expand Up @@ -468,12 +465,14 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
}
}

// Handle traditional HTTP proxy
// Handle traditional HTTP proxy and MITM outgoing requests (smokescreen - remote )
proxy.OnRequest().DoFunc(func(req *http.Request, pctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
// Set this on every request as every request mints a new goproxy.ProxyCtx
pctx.RoundTripper = rtFn

// For MITM requests intended for the remote host, the sole requirement was to configure the RoundTripper
// In the context of MITM request. Once the originating request (client - smokescreen) has been allowed
// goproxy/https.go calls proxy.filterRequest on the outgoing request (smokescreen - remote host) which calls this function
// in this case we ony want to configure the RoundTripper
if pctx.ConnectAction == goproxy.ConnectMitm {
return req, nil
}
Expand Down Expand Up @@ -571,13 +570,8 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
return rejectResponse(pctx, pctx.Error)
}

if pctx.ConnectAction == goproxy.ConnectMitm {
// If the connection is a MITM
// 1 we don't want to log as it will be done in HandleConnectFunc
// 2 we want to close idle connections as they are not closed by default
// and CANONICAL-PROXY-CN-CLOSE is called on InstrumentedConn.Close
proxy.Tr.CloseIdleConnections()
} else {
// We don't want to log if the connection is a MITM as it will be done in HandleConnectFunc
if pctx.ConnectAction != goproxy.ConnectMitm {
// In case of an error, this function is called a second time to filter the
// response we generate so this logger will be called once.
logProxy(pctx)
Expand Down
6 changes: 4 additions & 2 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1432,9 +1432,10 @@ func TestMitm(t *testing.T) {
r.NoError(err)
cfg.Listener = l

proxy := proxyServer(cfg)
proxy := BuildProxy(cfg)
httpProxy := httptest.NewServer(proxy)
remote := httptest.NewTLSServer(h)
client, err := proxyClient(proxy.URL)
client, err := proxyClient(httpProxy.URL)
r.NoError(err)

req, err := http.NewRequest("GET", remote.URL, nil)
Expand Down Expand Up @@ -1480,6 +1481,7 @@ func TestMitm(t *testing.T) {
r.NotNil(proxyDecision)
r.Contains(proxyDecision.Data, "proxy_type")
r.Equal("connect", proxyDecision.Data["proxy_type"])
proxy.Tr.CloseIdleConnections()
// check proxyclose log entry has information about the request headers
proxyClose := findCanonicalProxyClose(logHook.AllEntries())
r.NotNil(proxyClose)
Expand Down

0 comments on commit 85fea80

Please sign in to comment.