diff --git a/lib/ljsyscall/syscall/linux/syscalls.lua b/lib/ljsyscall/syscall/linux/syscalls.lua index fb6da2ea06..843e9e713e 100644 --- a/lib/ljsyscall/syscall/linux/syscalls.lua +++ b/lib/ljsyscall/syscall/linux/syscalls.lua @@ -483,33 +483,30 @@ local function get_maxnumnodes() return math.floor(((#line+1)/9)*32) end end + -- If we don't know, guess that the system has a max of 1024 nodes. + return 1024 +end + +local function ensure_bitmask(mask, size) + if ffi.istype(t.bitmask, mask) then return mask end + return t.bitmask(mask, size or get_maxnumnodes()) end function S.get_mempolicy(mode, mask, addr, flags) mode = mode or t.int1() - local size - if ffi.istype(t.bitmask, mask) then - -- if mask was provided by the caller, then use its size - -- and let the syscall error if it's too small - size = ffi.cast("uint64_t", tonumber(mask.size)) - else - local mask_for_size = t.bitmask(mask) - -- Size should be at least equals to maxnumnodes. - size = ffi.cast("uint64_t", math.max(tonumber(mask_for_size.size), get_maxnumnodes())) - mask = t.bitmask(mask, tonumber(size)) - end - local ret, err = C.get_mempolicy(mode, mask.mask, size, addr or 0, c.MPOL_FLAG[flags]) + mask = ensure_bitmask(mask); + local ret, err = C.get_mempolicy(mode, mask.mask, mask.size, addr or 0, c.MPOL_FLAG[flags]) if ret == -1 then return nil, t.error(err or errno()) end return { mode=mode[0], mask=mask } end function S.set_mempolicy(mode, mask) - mask = mktype(t.bitmask, mask) + mask = ensure_bitmask(mask); return retbool(C.set_mempolicy(c.MPOL_MODE[mode], mask.mask, mask.size)) end function S.migrate_pages(pid, from, to) - from = mktype(t.bitmask, from) - to = mktype(t.bitmask, to) + from = ensure_bitmask(from); + to = ensure_bitmask(to, from.size) assert(from.size == to.size, "incompatible nodemask sizes") return retbool(C.migrate_pages(pid or 0, from.size, from.mask, to.mask)) end diff --git a/src/lib/numa.lua b/src/lib/numa.lua index 6c7f6518ba..a600db7033 100644 --- a/src/lib/numa.lua +++ b/src/lib/numa.lua @@ -97,57 +97,6 @@ function unbind_cpu () bound_cpu = nil end -local blacklisted_kernels = { - '>=4.15', -} -local function sys_kernel () - return lib.readfile('/proc/sys/kernel/osrelease', '*all'):gsub('%s$', '') -end -local function parse_version_number (str) - local t = {} - for each in str:gmatch("([^.]+)") do - table.insert(t, tonumber(each) or 0) - end - return t -end -local function equals (v1, v2) - for i, p1 in ipairs(v1) do - local p2 = v2[i] or 0 - if p1 ~= p2 then return false end - end - return true -end -local function greater_or_equals (v1, v2) - for i, p1 in ipairs(v1) do - local p2 = v2[i] or 0 - if p2 > p1 then return false end - end - return true -end -local function greater (v1, v2) - return greater_or_equals(v1, v2) and not equals(v1, v2) -end -function is_blacklisted_kernel (v) - for _, each in ipairs(blacklisted_kernels) do - -- Greater or equal. - if each:sub(1, 2) == '>=' then - each = each:sub(3, #each) - local v1, v2 = parse_version_number(v), parse_version_number(each) - if greater_or_equals(v1, v2) then return true end - -- Greater than. - elseif each:sub(1, 1) == '>' then - each = each:sub(2, #each) - local v1, v2 = parse_version_number(v), parse_version_number(each) - if greater(v1, v2) then return true end - -- Equals. - else - local v1, v2 = parse_version_number(v), parse_version_number(each) - if equals(v1, v2) then return true end - end - end - return false -end - function bind_to_cpu (cpu) local function contains (t, e) for k,v in ipairs(t) do @@ -177,11 +126,6 @@ function unbind_numa_node () end function bind_to_numa_node (node, policy) - local kernel = sys_kernel() - if is_blacklisted_kernel(kernel) then - print(("WARNING: Buggy kernel '%s'. Not binding CPU to NUMA node."):format(kernel)) - return - end if node == bound_numa_node then return end if not node then return unbind_numa_node() end assert(not bound_numa_node, "already bound") @@ -247,9 +191,5 @@ function selftest () test_pci_affinity(pciaddr) end - assert(greater(parse_version_number('4.15'), parse_version_number('4.4.80'))) - assert(greater_or_equals(parse_version_number('4.15'), parse_version_number('4.15'))) - assert(not greater(parse_version_number('4.14'), parse_version_number('4.15'))) - print('selftest: numa: ok') end