//go:build integration
// +build integration

package rtnetlink

import (
	"testing"

	"github.com/cilium/ebpf"
	"github.com/cilium/ebpf/asm"
	"github.com/cilium/ebpf/rlimit"
	"github.com/jsimonetti/rtnetlink/v2/internal/testutils"
	"github.com/jsimonetti/rtnetlink/v2/internal/unix"
	"github.com/mdlayher/netlink"
)

// lo is the interface index used throughout these tests to reach the
// loopback interface, which is present in every network namespace.
var lo = uint32(1)

// xdpPrograms loads two independent, identical XDP programs that return
// XDP_PASS and returns their file descriptors. Each program is closed via
// tb.Cleanup when the test finishes; the test fails if either load fails.
func xdpPrograms(tb testing.TB) (int32, int32) {
	tb.Helper()

	// Load XDP_PASS (2) into the return value register and exit.
	bpfProgram := &ebpf.ProgramSpec{
		Type: ebpf.XDP,
		Instructions: asm.Instructions{
			asm.LoadImm(asm.R0, int64(2), asm.DWord),
			asm.Return(),
		},
		License: "MIT",
	}

	// load creates one program from bpfProgram. The cleanup is registered
	// right after a successful load, so an earlier program is still released
	// if a later load fails and aborts the test.
	load := func() *ebpf.Program {
		tb.Helper()

		prog, err := ebpf.NewProgram(bpfProgram)
		if err != nil {
			tb.Fatal(err)
		}
		tb.Cleanup(func() {
			prog.Close()
		})
		return prog
	}

	prog1 := load()
	prog2 := load()

	// Use the file descriptors of the programs.
	return int32(prog1.FD()), int32(prog2.FD())
}

// attachXDP applies the XDP settings in xdp to the link at ifIndex by
// sending an AF_UNSPEC link message through conn.Link.Set. It fails tb if
// the request is rejected.
func attachXDP(tb testing.TB, conn *Conn, ifIndex uint32, xdp *LinkXDP) {
	tb.Helper()

	err := conn.Link.Set(&LinkMessage{
		Family:     unix.AF_UNSPEC,
		Index:      ifIndex,
		Attributes: &LinkAttributes{XDP: xdp},
	})
	if err == nil {
		return
	}

	tb.Fatalf("attaching program with fd %d to link at ifindex %d: %s", xdp.FD, ifIndex, err)
}

// getXDP fetches the link at ifIndex and returns its XDP attach state and
// attached XDP program ID. It fails tb (it does not return an error) when
// the link cannot be fetched or the reply carries no XDP attributes.
func getXDP(tb testing.TB, conn *Conn, ifIndex uint32) (uint8, uint32) {
	tb.Helper()

	interf, err := conn.Link.Get(ifIndex)
	if err != nil {
		tb.Fatalf("getting link xdp properties: %s", err)
	}

	// Report a clear test failure instead of panicking on a nil dereference
	// if the reply lacks the attributes we are about to read.
	if interf.Attributes == nil || interf.Attributes.XDP == nil {
		tb.Fatalf("link at ifindex %d has no XDP attributes", ifIndex)
	}

	return interf.Attributes.XDP.Attached, interf.Attributes.XDP.ProgID
}

// TestLinkXDPAttach attaches an XDP program in SKB mode to the loopback
// interface with several FD/ExpectedFD combinations and checks that the
// link reports the program as attached with a non-zero program ID.
//
// None of the cases set XDP_FLAGS_REPLACE, and every case is expected to
// succeed, including the one whose ExpectedFD differs from FD.
func TestLinkXDPAttach(t *testing.T) {
	if err := rlimit.RemoveMemlock(); err != nil {
		t.Fatal(err)
	}

	conn, err := Dial(&netlink.Config{NetNS: testutils.NetNS(t)})
	if err != nil {
		t.Fatalf("failed to establish netlink socket: %v", err)
	}
	defer conn.Close()

	progFD1, progFD2 := xdpPrograms(t)

	tests := []struct {
		name string
		xdp  *LinkXDP
	}{
		{
			name: "with FD, no expected FD",
			xdp: &LinkXDP{
				FD:    progFD1,
				Flags: unix.XDP_FLAGS_SKB_MODE,
			},
		},
		{
			name: "with FD, expected FD == FD",
			xdp: &LinkXDP{
				FD:         progFD1,
				ExpectedFD: progFD1,
				Flags:      unix.XDP_FLAGS_SKB_MODE,
			},
		},
		{
			name: "with FD, expected FD != FD",
			xdp: &LinkXDP{
				FD:         progFD1,
				ExpectedFD: progFD2,
				Flags:      unix.XDP_FLAGS_SKB_MODE,
			},
		},
		{
			name: "with FD, expected FD < 0",
			xdp: &LinkXDP{
				FD:         progFD1,
				ExpectedFD: -1,
				Flags:      unix.XDP_FLAGS_SKB_MODE,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			attachXDP(t, conn, lo, tt.xdp)

			attached, progID := getXDP(t, conn, lo)
			// NOTE: the reported attach state is compared against the SKB-mode
			// *flag* constant. In linux/if_link.h, XDP_ATTACHED_SKB and
			// XDP_FLAGS_SKB_MODE are both 2, so the comparison holds.
			if attached != unix.XDP_FLAGS_SKB_MODE {
				t.Fatalf("XDP attached state does not match. Got: %d, wanted: %d", attached, unix.XDP_FLAGS_SKB_MODE)
			}
			// The check above already guarantees the program is attached.
			if progID == 0 {
				t.Fatal("XDP program should be attached but program ID is 0")
			}
		})
	}
}

// TestLinkXDPClear attaches an XDP program to the loopback interface in SKB
// mode, then detaches it by setting the program FD to -1. It verifies that
// the link reports neither a program ID nor an attach state afterwards.
func TestLinkXDPClear(t *testing.T) {
	if err := rlimit.RemoveMemlock(); err != nil {
		t.Fatal(err)
	}

	conn, err := Dial(&netlink.Config{NetNS: testutils.NetNS(t)})
	if err != nil {
		t.Fatalf("failed to establish netlink socket: %v", err)
	}
	defer conn.Close()

	progFD, _ := xdpPrograms(t)

	// First attach the program, then clear it again: FD -1 detaches the
	// current BPF program from the link.
	for _, fd := range []int32{progFD, -1} {
		attachXDP(t, conn, lo, &LinkXDP{
			FD:    fd,
			Flags: unix.XDP_FLAGS_SKB_MODE,
		})
	}

	attached, progID := getXDP(t, conn, lo)
	switch {
	case progID != 0:
		t.Fatalf("there is still a program loaded, while we cleared the link")
	case attached != 0:
		t.Fatalf(
			"XDP attached state does not match. Got: %d, wanted: %d\nThere should be no program loaded",
			attached, 0,
		)
	}
}

// TestLinkXDPReplace verifies the XDP_FLAGS_REPLACE semantics on the
// loopback interface:
//   - a replace whose ExpectedFD does not match the attached program must
//     fail and leave the attached program unchanged;
//   - a replace whose ExpectedFD matches must swap in the new program.
func TestLinkXDPReplace(t *testing.T) {
	// Kernel 5.7 added support for EXPECTED_FD and XDP_FLAGS_REPLACE. We check
	// whether the test host kernel meets this requirement and skip the test
	// with a notice if it does not. Running the code on a kernel older than
	// 5.7 fails with an "invalid argument" error.
	// source kernel 5.6:
	// https://elixir.bootlin.com/linux/v5.6/source/net/core/dev.c#L8662
	// source kernel 5.7:
	// https://elixir.bootlin.com/linux/v5.7/source/net/core/dev.c#L8674
	testutils.SkipOnOldKernel(t, "5.7", "XDP_FLAGS_REPLACE")

	if err := rlimit.RemoveMemlock(); err != nil {
		t.Fatal(err)
	}

	conn, err := Dial(&netlink.Config{NetNS: testutils.NetNS(t)})
	if err != nil {
		t.Fatalf("failed to establish netlink socket: %v", err)
	}
	defer conn.Close()

	progFD1, progFD2 := xdpPrograms(t)

	attachXDP(t, conn, lo, &LinkXDP{
		FD:    progFD1,
		Flags: unix.XDP_FLAGS_SKB_MODE,
	})

	_, progID1 := getXDP(t, conn, lo)

	// progFD1 is attached, so expecting progFD2 must make the replace fail.
	if err := conn.Link.Set(&LinkMessage{
		Family: unix.AF_UNSPEC,
		Index:  lo,
		Attributes: &LinkAttributes{
			XDP: &LinkXDP{
				FD:         progFD2,
				ExpectedFD: progFD2,
				Flags:      unix.XDP_FLAGS_SKB_MODE | unix.XDP_FLAGS_REPLACE,
			},
		},
	}); err == nil {
		// err is nil in this branch, so there is no error value to report.
		t.Fatal("replaced XDP program while expected FD did not match")
	}

	_, progID2 := getXDP(t, conn, lo)
	if progID2 != progID1 {
		t.Fatal("XDP prog ID does not match previous program ID, which it should")
	}

	// With the matching ExpectedFD, the replace must succeed.
	attachXDP(t, conn, lo, &LinkXDP{
		FD:         progFD2,
		ExpectedFD: progFD1,
		Flags:      unix.XDP_FLAGS_SKB_MODE | unix.XDP_FLAGS_REPLACE,
	})

	_, progID2 = getXDP(t, conn, lo)
	if progID2 == progID1 {
		t.Fatal("XDP prog ID does match previous program ID, which it shouldn't")
	}
}

// TestLinkListByKind checks that listing links by a kind that no interface
// can have succeeds and yields no links.
func TestLinkListByKind(t *testing.T) {
	if err := rlimit.RemoveMemlock(); err != nil {
		t.Fatal(err)
	}

	ns := testutils.NetNS(t)
	conn, err := Dial(&netlink.Config{NetNS: ns})
	if err != nil {
		t.Fatalf("failed to establish netlink socket: %v", err)
	}
	defer conn.Close()

	const impossibleKind = "SomeImpossibleLinkKind"

	links, err := conn.Link.ListByKind(impossibleKind)
	if err != nil {
		t.Fatalf("LinkListByKind() finished with error: %v", err)
	}

	if n := len(links); n != 0 {
		t.Fatalf("LinkListByKind() found %d links with impossible kind", n)
	}
}

func TestLinkSetMaster(t *testing.T) {
	ns := testutils.NetNS(t)
	conn, err := Dial(&netlink.Config{NetNS: ns})
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()

	const (
		bridgeIndex = 2001
		vethIndex   = 2002
	)

	// Create a bridge
	err = conn.Link.New(&LinkMessage{
		Index: bridgeIndex,
		Attributes: &LinkAttributes{
			Name: "testbr0",
			Info: &LinkInfo{
				Kind: "bridge",
			},
		},
	})
	if err != nil {
		t.Fatalf("failed to create bridge: %v", err)
	}
	defer conn.Link.Delete(bridgeIndex)

	// Create a dummy interface
	err = conn.Link.New(&LinkMessage{
		Index: vethIndex,
		Attributes: &LinkAttributes{
			Name: "testdum0",
			Info: &LinkInfo{
				Kind: "dummy",
			},
		},
	})
	if err != nil {
		t.Fatalf("failed to create dummy interface: %v", err)
	}
	defer conn.Link.Delete(vethIndex)

	// Enslave the dummy to the bridge
	err = conn.Link.SetMaster(vethIndex, bridgeIndex, nil)
	if err != nil {
		t.Fatalf("failed to set master: %v", err)
	}

	// Verify it's enslaved
	got, err := conn.Link.Get(vethIndex)
	if err != nil {
		t.Fatalf("failed to get link: %v", err)
	}

	if got.Attributes.Master == nil {
		t.Fatal("expected Master to be set, got nil")
	}
	if *got.Attributes.Master != bridgeIndex {
		t.Fatalf("unexpected master index:\n got: %d\nwant: %d", *got.Attributes.Master, bridgeIndex)
	}
}

func TestLinkRemoveMaster(t *testing.T) {
	ns := testutils.NetNS(t)
	conn, err := Dial(&netlink.Config{NetNS: ns})
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()

	const (
		bridgeIndex = 2101
		dummyIndex  = 2102
	)

	// Create a bridge
	err = conn.Link.New(&LinkMessage{
		Index: bridgeIndex,
		Attributes: &LinkAttributes{
			Name: "testbr1",
			Info: &LinkInfo{
				Kind: "bridge",
			},
		},
	})
	if err != nil {
		t.Fatalf("failed to create bridge: %v", err)
	}
	defer conn.Link.Delete(bridgeIndex)

	// Create a dummy interface
	err = conn.Link.New(&LinkMessage{
		Index: dummyIndex,
		Attributes: &LinkAttributes{
			Name: "testdum1",
			Info: &LinkInfo{
				Kind: "dummy",
			},
		},
	})
	if err != nil {
		t.Fatalf("failed to create dummy interface: %v", err)
	}
	defer conn.Link.Delete(dummyIndex)

	// Enslave it to the bridge
	err = conn.Link.SetMaster(dummyIndex, bridgeIndex, nil)
	if err != nil {
		t.Fatalf("failed to set master: %v", err)
	}

	// Verify it's enslaved
	got, err := conn.Link.Get(dummyIndex)
	if err != nil {
		t.Fatalf("failed to get link: %v", err)
	}
	if got.Attributes.Master == nil || *got.Attributes.Master != bridgeIndex {
		t.Fatal("interface was not enslaved")
	}

	// Remove from master
	err = conn.Link.RemoveMaster(dummyIndex)
	if err != nil {
		t.Fatalf("failed to remove master: %v", err)
	}

	// Verify it's un-enslaved
	got, err = conn.Link.Get(dummyIndex)
	if err != nil {
		t.Fatalf("failed to get link: %v", err)
	}

	// Master should be nil or 0
	if got.Attributes.Master != nil && *got.Attributes.Master != 0 {
		t.Fatalf("expected Master to be 0 or nil, got %d", *got.Attributes.Master)
	}
}
