Skip to content
Snippets Groups Projects
Commit 70ee5660 authored by cyberta's avatar cyberta Committed by cyberta
Browse files

fix index out of range panic in case doh resolution returns empty A record...

fix index out of range panic in case doh resolution returns empty A record slice. While being at it add IPv6 resolution to DoH implementation
parent 77eafa57
Branches main
No related tags found
1 merge request!53fix index out of range panic in case DoH resolution returns empty A record
Pipeline #282543 passed
...@@ -36,17 +36,48 @@ func dohQuery(domain string) (string, error) { ...@@ -36,17 +36,48 @@ func dohQuery(domain string) (string, error) {
HTTPClient: &http.Client{Timeout: 10 * time.Second}, HTTPClient: &http.Client{Timeout: 10 * time.Second},
} }
// lookup A records for IPv4
ips, _, err := resolver.LookupA(domain) ips, _, err := resolver.LookupA(domain)
if err != nil { if err != nil {
log.Warn(). log.Warn().
Str("resolver", dnsServer). Str("resolver", dnsServer).
Str("domain", domain). Str("domain", domain).
Err(err). Err(err).
Msg("Could not resolve host with DNS over HTTPs") Msg("Could not resolve host's IPv4 address with DNS over HTTPS")
continue continue
} }
if len(ips) > 0 {
return ips[0].IP4, nil return ips[0].IP4, nil
} }
// fallback: lookup AAAA records for IPv6
log.Warn().
Str("resolver", dnsServer).
Str("domain", domain).
Err(err).
Msg("No A records found for domain")
v6Ips, _, err := resolver.LookupAAAA(domain)
if err != nil {
log.Warn().
Str("resolver", dnsServer).
Str("domain", domain).
Err(err).
Msg("Could not resolve host's IPv6 address with DNS over HTTPS")
continue
}
if len(v6Ips) > 0 {
return v6Ips[0].IP6, nil
}
log.Warn().
Str("resolver", dnsServer).
Str("domain", domain).
Err(err).
Msg("No AAAA records found for domain")
}
return "", errors.New("Could not resolve ip with DNS over HTTPS. Tried all resolvers") return "", errors.New("Could not resolve ip with DNS over HTTPS. Tried all resolvers")
} }
...@@ -2,6 +2,7 @@ package bootstrap ...@@ -2,6 +2,7 @@ package bootstrap
import ( import (
"os" "os"
"strings"
"testing" "testing"
"github.com/rs/zerolog" "github.com/rs/zerolog"
...@@ -18,4 +19,19 @@ func TestDoh(t *testing.T) { ...@@ -18,4 +19,19 @@ func TestDoh(t *testing.T) {
ip, err := dohQuery("leap.se") ip, err := dohQuery("leap.se")
assert.NoError(t, err, "dohQuery failed") assert.NoError(t, err, "dohQuery failed")
assert.NotNil(t, ip, "ip should not be nil") assert.NotNil(t, ip, "ip should not be nil")
assert.NotEmpty(t, ip, "ip should not be empty")
}
func TestDohHandleEmptyRecords(t *testing.T) {
ip, err := dohQuery("notexising-kjhfdfghfhjgiuzuzfgfcdxfsa.org")
assert.Error(t, err, "dohQuery failed")
assert.Empty(t, ip, "ip should be empty")
}
func TestDohHandleAAAARecords(t *testing.T) {
ip, err := dohQuery("ipv6.google.com")
assert.NoError(t, err, "dohQuery failed")
assert.NotNil(t, ip, "ip should not be nil")
assert.NotEmpty(t, ip, "ip should not be empty")
assert.Equal(t, true, strings.Contains(ip, ":") && !strings.Contains(ip, "."))
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
bitmask_storage "0xacab.org/leap/bitmask-core/pkg/storage" bitmask_storage "0xacab.org/leap/bitmask-core/pkg/storage"
...@@ -57,16 +58,22 @@ func (c *Config) getAPIClient() *http.Client { ...@@ -57,16 +58,22 @@ func (c *Config) getAPIClient() *http.Client {
Str("domain", addr). Str("domain", addr).
Msg("Resolving host with DNS over HTTPs") Msg("Resolving host with DNS over HTTPs")
ip4, err := dohQuery(c.Host) ip, err := dohQuery(c.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debug(). log.Debug().
Str("domain", addr). Str("domain", addr).
Str("ip4", ip4). Str("ip", ip).
Msg("Sucessfully resolved host via DNS over HTTPs") Msg("Sucessfully resolved host via DNS over HTTPs")
addr = fmt.Sprintf("%s:%d", ip4, c.Port) if strings.Contains(ip, ":") {
// IPv6 address requires extra brackets in order to
// distinguish address from port
addr = fmt.Sprintf("[%s]:%d", ip, c.Port)
} else {
addr = fmt.Sprintf("%s:%d", ip, c.Port)
}
} }
roller, err := utls.NewRoller() roller, err := utls.NewRoller()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment