diff --git a/libcontainer/cgroups/utils.go b/libcontainer/cgroups/utils.go index a6e9b29a2..2a9e59f2c 100644 --- a/libcontainer/cgroups/utils.go +++ b/libcontainer/cgroups/utils.go @@ -409,14 +409,22 @@ func RemovePaths(paths map[string]string) (err error) { } func GetHugePageSize() ([]string, error) { - var pageSizes []string - sizeList := []string{"B", "KB", "MB", "GB", "TB", "PB"} files, err := ioutil.ReadDir("/sys/kernel/mm/hugepages") if err != nil { - return pageSizes, err + return []string{}, err } + var fileNames []string for _, st := range files { - nameArray := strings.Split(st.Name(), "-") + fileNames = append(fileNames, st.Name()) + } + return getHugePageSizeFromFilenames(fileNames) +} + +func getHugePageSizeFromFilenames(fileNames []string) ([]string, error) { + var pageSizes []string + sizeList := []string{"B", "KB", "MB", "GB", "TB", "PB"} + for _, fileName := range fileNames { + nameArray := strings.Split(fileName, "-") pageSize, err := units.RAMInBytes(nameArray[1]) if err != nil { return []string{}, err diff --git a/libcontainer/cgroups/utils_test.go b/libcontainer/cgroups/utils_test.go index fcca9d73a..3214b9de0 100644 --- a/libcontainer/cgroups/utils_test.go +++ b/libcontainer/cgroups/utils_test.go @@ -4,6 +4,7 @@ package cgroups import ( "bytes" + "errors" "fmt" "reflect" "strings" @@ -421,3 +422,38 @@ func TestFindCgroupMountpointAndRoot(t *testing.T) { } } } + +func TestGetHugePageSizeImpl(t *testing.T) { + + testCases := []struct { + inputFiles []string + outputPageSizes []string + err error + }{ + { + inputFiles: []string{"hugepages-1048576kB", "hugepages-2048kB", "hugepages-32768kB", "hugepages-64kB"}, + outputPageSizes: []string{"1GB", "2MB", "32MB", "64KB"}, + err: nil, + }, + { + inputFiles: []string{}, + outputPageSizes: []string{}, + err: nil, + }, + { + inputFiles: []string{"hugepages-a"}, + outputPageSizes: []string{}, + err: errors.New("invalid size: 'a'"), + }, + } + + for _, c := range testCases { + pageSizes, err := getHugePageSizeFromFilenames(c.inputFiles) + if len(pageSizes) != 0 && len(c.outputPageSizes) != 0 && !reflect.DeepEqual(pageSizes, c.outputPageSizes) { + t.Errorf("expected %s, got %s", c.outputPageSizes, pageSizes) + } + if err != nil && err.Error() != c.err.Error() { + t.Errorf("expected error %s, got %s", c.err, err) + } + } +}