glider/rule/check.go

141 lines
2.8 KiB
Go

package rule
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strings"
"time"
"github.com/nadoo/glider/pkg/pool"
)
// Checker is a forwarder health checker.
//
// The implementations in this file probe a forwarder by dialing a TCP
// address through it, by issuing an HTTP(S) request through it, or by
// running an external program.
type Checker interface {
// Check probes fwdr. elap is the duration the implementation reports for
// the probe, and err is non-nil when the probe failed.
Check(fwdr *Forwarder) (elap time.Duration, err error)
}
// tcpChecker is a Checker that probes a forwarder by opening a TCP
// connection through it.
type tcpChecker struct {
	// addr is the "host:port" target dialed through the forwarder.
	addr string
	// timeout is the value handed to newTcpChecker.
	timeout time.Duration
}

// newTcpChecker returns a tcpChecker targeting addr.
//
// addr may be given with or without a port. When the port is missing or
// empty, port 80 is used. Bare hosts ("example.com"), hosts with a trailing
// colon ("example.com:"), and IPv6 literals with or without brackets
// ("::1", "[::1]") are all normalized to a dialable "host:port" form.
func newTcpChecker(addr string, timeout time.Duration) *tcpChecker {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		// No usable port: the whole string is the host. Unwrap a bracketed
		// IPv6 literal so JoinHostPort does not bracket it a second time.
		host = addr
		if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
			host = host[1 : n-1]
		}
	}
	if port == "" {
		// Join the bare host (not the raw addr) so that a trailing ':' in
		// the input is not mistaken for part of an IPv6 literal.
		addr = net.JoinHostPort(host, "80")
	}
	return &tcpChecker{addr: addr, timeout: timeout}
}
// Check implements the Checker interface.
//
// It dials c.addr over TCP through fwdr and closes the connection right
// away. A dial failure yields a zero duration and the dial error; otherwise
// the time elapsed since the start of the call is returned.
func (c *tcpChecker) Check(fwdr *Forwarder) (time.Duration, error) {
	began := time.Now()

	conn, dialErr := fwdr.Dial("tcp", c.addr)
	if dialErr != nil {
		return 0, dialErr
	}

	// Only reachability matters; nothing is sent or read.
	conn.Close()

	took := time.Since(began)
	return took, nil
}
// httpChecker is a Checker that probes a forwarder by sending an HTTP GET
// request through it and matching the first line of the response.
type httpChecker struct {
	// addr is the "host:port" target dialed through the forwarder.
	addr string
	// uri is the request target placed on the GET request line.
	uri string
	// expect is the regular expression source the first response line must
	// match; it is kept for error messages.
	expect string
	// timeout bounds the probe; a value <= 0 sets no connection deadline.
	timeout time.Duration
	// tlsConfig is non-nil only when the probe runs over TLS.
	tlsConfig *tls.Config
	// serverName is addr without its port; it is sent in the Host header
	// and used as the TLS server name.
	serverName string
	// regex is the compiled form of expect.
	regex *regexp.Regexp
}

// newHttpChecker returns an httpChecker for addr.
//
// addr may be given with or without a port. When the port is missing or
// empty, it defaults to 443 if withTLS is set and to 80 otherwise. Bare
// hosts, hosts with a trailing colon, and IPv6 literals with or without
// brackets are normalized to a dialable "host:port" form.
//
// expect is compiled as a regular expression; an invalid pattern panics
// (regexp.MustCompile), because the constructor has no error return.
func newHttpChecker(addr, uri, expect string, timeout time.Duration, withTLS bool) *httpChecker {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		// No usable port: the whole string is the host. Unwrap a bracketed
		// IPv6 literal so JoinHostPort does not bracket it a second time.
		host = addr
		if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
			host = host[1 : n-1]
		}
	}
	if port == "" {
		port = "80"
		if withTLS {
			port = "443"
		}
		// Join the bare host (not the raw addr) so that a trailing ':' in
		// the input is not mistaken for part of an IPv6 literal.
		addr = net.JoinHostPort(host, port)
	}

	c := &httpChecker{
		addr:    addr,
		uri:     uri,
		expect:  expect,
		timeout: timeout,
		regex:   regexp.MustCompile(expect),
		// addr always contains a ':' at this point, so the index is valid.
		// Brackets of an IPv6 literal are kept, as required in a Host header.
		serverName: addr[:strings.LastIndex(addr, ":")],
	}
	if withTLS {
		c.tlsConfig = &tls.Config{ServerName: c.serverName}
	}
	return c
}
// Check implements the Checker interface.
//
// It dials c.addr through fwdr, performs a TLS handshake when the checker
// was built with TLS, sends a minimal HTTP/1.1 GET for c.uri and matches the
// first line of the response against the expect pattern.
//
// On success it returns the time elapsed since the start of the call. Dial,
// handshake, write, read and pattern-mismatch failures return a zero
// duration with the error. When a positive timeout is configured and the
// probe took longer than it, the elapsed time is returned together with a
// "timeout" error. A timeout <= 0 disables both the connection deadline and
// that final check.
func (c *httpChecker) Check(fwdr *Forwarder) (time.Duration, error) {
	startTime := time.Now()

	rc, err := fwdr.Dial("tcp", c.addr)
	if err != nil {
		return 0, err
	}

	// Arm the deadline before any further I/O so that it also bounds the TLS
	// handshake; a deadline set on the underlying connection applies to the
	// TLS layer built on top of it. The result is deliberately ignored
	// (best effort), as connections that cannot honor a deadline are still
	// probed.
	if c.timeout > 0 {
		rc.SetDeadline(time.Now().Add(c.timeout))
	}

	if c.tlsConfig != nil {
		tlsConn := tls.Client(rc, c.tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			// Closing the TLS connection also closes the dialed connection.
			tlsConn.Close()
			return 0, err
		}
		rc = tlsConn
	}
	defer rc.Close()

	if _, err = io.WriteString(rc,
		"GET "+c.uri+" HTTP/1.1\r\nHost:"+c.serverName+"\r\n\r\n"); err != nil {
		return 0, err
	}

	r := pool.GetBufReader(rc)
	defer pool.PutBufReader(r)

	// Only the first response line is inspected.
	line, err := r.ReadString('\n')
	if err != nil {
		return 0, err
	}

	if !c.regex.MatchString(line) {
		return 0, fmt.Errorf("expect: %s, got: %s", c.expect, line)
	}

	elapsed := time.Since(startTime)
	// The overall budget includes the dial, which the connection deadline
	// above does not cover. With no timeout configured there is no budget to
	// exceed, so the probe must not be reported as timed out.
	if c.timeout > 0 && elapsed > c.timeout {
		return elapsed, errors.New("timeout")
	}

	return elapsed, nil
}
// fileChecker is a Checker backed by an external program.
type fileChecker struct {
	// path is the command handed to exec.Command by Check.
	path string
}

// newFileChecker returns a fileChecker that runs the program at path.
func newFileChecker(path string) *fileChecker {
	checker := new(fileChecker)
	checker.path = path
	return checker
}
// Check implements the Checker interface.
//
// It runs the program at c.path with the current process environment plus
// FORWARDER_ADDR and FORWARDER_URL describing fwdr, and with the program's
// standard output connected to this process's standard output. The returned
// duration is always zero; the error is whatever running the command
// returned.
func (c *fileChecker) Check(fwdr *Forwarder) (time.Duration, error) {
	probe := exec.Command(c.path)
	probe.Stdout = os.Stdout
	probe.Env = append(os.Environ(),
		"FORWARDER_ADDR="+fwdr.Addr(),
		"FORWARDER_URL="+fwdr.URL(),
	)

	runErr := probe.Run()
	return 0, runErr
}