diff --git a/main.go b/main.go index c36f96b..964454f 100644 --- a/main.go +++ b/main.go @@ -1,119 +1,181 @@ package main import ( - "errors" - "flag" - "fmt" - "io" - "log" - "net" - "strings" - "time" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "strings" + "sync" + "time" ) var ( - masterAddr *net.TCPAddr - raddr *net.TCPAddr - saddr *net.TCPAddr + saddr *net.TCPAddr // Address of the sentinel service + slock *sync.Mutex // Guard for above var - localAddr = flag.String("listen", ":9999", "local address") - sentinelAddr = flag.String("sentinel", ":26379", "remote address") - masterName = flag.String("master", "", "name of the master redis node") + masterAddr *net.TCPAddr // Address of the redis master + mlock *sync.Mutex // Guard for above var + + localAddr = flag.String("listen", ":9999", "local address") + sentinelAddr = flag.String("sentinel", ":26379", "remote address") + masterName = flag.String("master", "", "name of the master redis node") ) func main() { - flag.Parse() + flag.Parse() - laddr, err := net.ResolveTCPAddr("tcp", *localAddr) - if err != nil { - log.Fatal("Failed to resolve local address: %s", err) - } - saddr, err = net.ResolveTCPAddr("tcp", *sentinelAddr) - if err != nil { - log.Fatal("Failed to resolve sentinel address: %s", err) - } + slock = &sync.Mutex{} + mlock = &sync.Mutex{} - go master() + laddr, err := net.ResolveTCPAddr("tcp", *localAddr) + if err != nil { + log.Fatalf("Failed to resolve local address: %s", err.Error()) + } + resolveSentinel(*sentinelAddr) - listener, err := net.ListenTCP("tcp", laddr) - if err != nil { - log.Fatal(err) - } + // If sentinel's address is set to nil, this goroutine will resolve it and set the var again + go sentinelUpdater(*sentinelAddr) - for { - conn, err := listener.AcceptTCP() - if err != nil { - log.Println(err) - continue - } + // Continuously query sentinel for the master address, updating masterAddr when needed + go master() - go proxy(conn, masterAddr) - } + listener, err := net.ListenTCP("tcp", laddr) + if err != nil { + log.Fatal(err) + } + + for { + conn, err := listener.AcceptTCP() + if err != nil { + log.Println(err) + continue + } + + go proxy(conn, masterAddr) + } +} + +func sentinelUpdater(sentinelAddr string) { + // Resolve the address of sentinel when needed + for { + if saddr == nil { + log.Print("Resolving sentinel address") + resolveSentinel(sentinelAddr) + } + time.Sleep(1 * time.Second) + } +} + +func resolveSentinel(sentinelAddr string) { + // var err error + addr, err := net.ResolveTCPAddr("tcp", sentinelAddr) + if err != nil { + log.Printf("Failed to resolve sentinel address: %s", err.Error()) + return + } + slock.Lock() + saddr = addr + slock.Unlock() + // TODO other cases when saddr isn't valid } func master() { - var err error - for { - masterAddr, err = getMasterAddr(saddr, *masterName) - if err != nil { - log.Println(err) - } - time.Sleep(1 * time.Second) - } + var err error + var tempAddr *net.TCPAddr + for { + tempAddr, err = getMasterAddr() + if err != nil { + log.Printf("Failed to get master addres: %s", err.Error()) + } else { + mlock.Lock() + masterAddr = tempAddr + mlock.Unlock() + } + time.Sleep(1 * time.Second) + } } func pipe(r io.Reader, w io.WriteCloser) { - io.Copy(w, r) - w.Close() + io.Copy(w, r) + w.Close() } func proxy(local io.ReadWriteCloser, remoteAddr *net.TCPAddr) { - remote, err := net.DialTCP("tcp", nil, remoteAddr) - if err != nil { - log.Println(err) - local.Close() - return - } - go pipe(local, remote) - go pipe(remote, local) + remote, err := net.DialTCP("tcp", nil, remoteAddr) + if err != nil { + log.Println(err) + local.Close() + return + } + go pipe(local, remote) + go pipe(remote, local) } -func getMasterAddr(sentinelAddress *net.TCPAddr, masterName string) (*net.TCPAddr, error) { - conn, err := net.DialTCP("tcp", nil, sentinelAddress) - if err != nil { - return nil, err - } +// Connect to Sentinel and query it to find the redis master +func getMasterAddr() (*net.TCPAddr, error) { + // Connect to sentinel + // If the connection times out, that master is probably gone. + // Mark saddr as nil so that the resolver thread will update it later. + // Create a local copy of the sentinel address, it can change under our feet + slock.Lock() + if saddr == nil { + defer slock.Unlock() + return nil, errors.New("Sentinel address not available") + } + local_saddr := *saddr + slock.Unlock() - defer conn.Close() + sentConn, err := dialTimeout(&local_saddr, 5*time.Second) + if err != nil { + log.Printf("Connecting to sentinel master timed out/failed: %s\n", err.Error()) + slock.Lock() + saddr = nil + slock.Unlock() + return nil, err + } + defer sentConn.Close() - conn.Write([]byte(fmt.Sprintf("sentinel get-master-addr-by-name %s\n", masterName))) + // We connected to the master, ask for the redis master + sentConn.Write([]byte(fmt.Sprintf("sentinel get-master-addr-by-name %s\n", *masterName))) - b := make([]byte, 256) - _, err = conn.Read(b) - if err != nil { - log.Fatal(err) - } + b := make([]byte, 256) + _, err = sentConn.Read(b) + if err != nil { + return nil, err + } + parts := strings.Split(string(b), "\r\n") + if len(parts) < 5 { + err = errors.New("Couldn't get master address from sentinel") + return nil, err + } - parts := strings.Split(string(b), "\r\n") + // Parse the address for the master node + stringaddr := fmt.Sprintf("%s:%s", parts[2], parts[4]) + addr, err := net.ResolveTCPAddr("tcp", stringaddr) + if err != nil { + return nil, err + } - if len(parts) < 5 { - err = errors.New("Couldn't get master address from sentinel") - return nil, err - } - - //getting the string address for the master node - stringaddr := fmt.Sprintf("%s:%s", parts[2], parts[4]) - addr, err := net.ResolveTCPAddr("tcp", stringaddr) - - if err != nil { - return nil, err - } - - //check that there's actually someone listening on that address - conn2, err := net.DialTCP("tcp", nil, addr) - if err == nil { - defer conn2.Close() - } - - return addr, err + // Verify the returned address is actually listening + // TODO is this really needed? + conn2, err := dialTimeout(addr, 5*time.Second) + if err != nil { + return nil, err + } + defer conn2.Close() + return addr, err +} + +// Connect to a TCPAddr, failing if a timeout is exceeded or other error encountered +func dialTimeout(destAddr *net.TCPAddr, timeout time.Duration) (*net.TCPConn, error) { + d := net.Dialer{Timeout: timeout} + netcon, err := d.Dial("tcp", fmt.Sprintf("%s:%d", destAddr.IP, destAddr.Port)) + if err != nil { + return nil, err + } + conn, _ := netcon.(*net.TCPConn) + return conn, nil }