From 8217cd2238c40cf77208aa27a7cc09879e685890 Mon Sep 17 00:00:00 2001
From: Yu Zhao <yuzhao@google.com>
Date: Mon, 5 Apr 2021 04:35:07 -0600
Subject: [PATCH 06/10] mm: multigenerational lru: aging

The aging produces young generations. Given an lruvec, the aging
traverses lruvec_memcg()->mm_list and calls walk_page_range() to scan
PTEs for accessed pages. Upon finding one, the aging updates its
generation number to max_seq (modulo MAX_NR_GENS). After each round of
traversal, the aging increments max_seq. The aging is due when
min_seq[] reaches max_seq-1.

The aging uses the following optimizations when walking page tables:
  1) It skips non-leaf PMD entries that have the accessed bit cleared
  when CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG=y.
  2) It does not zigzag between a PGD table and the same PMD or PTE
  table spanning multiple VMAs. In other words, it finishes all the
  VMAs within the range of the same PMD or PTE table before it returns
  to this PGD table. This optimizes workloads that have large numbers
  of tiny VMAs, especially when CONFIG_PGTABLE_LEVELS=5.

Signed-off-by: Yu Zhao <yuzhao@google.com>
Tested-by: Konstantin Kharlamov <Hi-Angel@yandex.ru>
Change-Id: I3ae8abc3100d023cecb3a699d86020ae6fc10a45
---
 include/linux/memcontrol.h |   3 +
 include/linux/mmzone.h     |   9 +
 include/linux/oom.h        |  16 +
 include/linux/swap.h       |   3 +
 mm/memcontrol.c            |   5 +
 mm/oom_kill.c              |   4 +-
 mm/rmap.c                  |   8 +
 mm/vmscan.c                | 948 +++++++++++++++++++++++++++++++++++++
 8 files changed, 994 insertions(+), 2 deletions(-)

--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1367,10 +1367,13 @@ mem_cgroup_print_oom_meminfo(struct mem_
 
 static inline void lock_page_memcg(struct page *page)
 {
+	/* to match page_memcg_rcu() */
+	rcu_read_lock();
 }
 
 static inline void unlock_page_memcg(struct page *page)
 {
+	rcu_read_unlock();
 }
 
 static inline void mem_cgroup_handle_over_high(void)
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -295,6 +295,7 @@ enum lruvec_flags {
 };
 
 struct lruvec;
+struct page_vma_mapped_walk;
 
 #define LRU_GEN_MASK		((BIT(LRU_GEN_WIDTH) - 1) << LRU_GEN_PGOFF)
 #define LRU_REFS_MASK		((BIT(LRU_REFS_WIDTH) - 1) << LRU_REFS_PGOFF)
@@ -393,6 +394,7 @@ struct mm_walk_args {
 
 void lru_gen_init_state(struct mem_cgroup *memcg, struct lruvec *lruvec);
 void lru_gen_change_state(bool enable, bool main, bool swap);
+void lru_gen_look_around(struct page_vma_mapped_walk *pvmw);
 
 #ifdef CONFIG_MEMCG
 void lru_gen_init_memcg(struct mem_cgroup *memcg);
@@ -409,6 +411,10 @@ static inline void lru_gen_change_state(
 {
 }
 
+static inline void lru_gen_look_around(struct page_vma_mapped_walk *pvmw)
+{
+}
+
 #ifdef CONFIG_MEMCG
 static inline void lru_gen_init_memcg(struct mem_cgroup *memcg)
 {
@@ -1028,6 +1034,9 @@ typedef struct pglist_data {
 
 	unsigned long		flags;
 
+#ifdef CONFIG_LRU_GEN
+	struct mm_walk_args	mm_walk_args;
+#endif
 	ZONE_PADDING(_pad2_)
 
 	/* Per-node vmstats */
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -57,6 +57,22 @@ struct oom_control {
 extern struct mutex oom_lock;
 extern struct mutex oom_adj_mutex;
 
+#ifdef CONFIG_MMU
+extern struct task_struct *oom_reaper_list;
+extern struct wait_queue_head oom_reaper_wait;
+
+static inline bool oom_reaping_in_progress(void)
+{
+	/* racy check: oom reaping could be in progress if the list is non-NULL or nobody waits */
+	return READ_ONCE(oom_reaper_list) || !waitqueue_active(&oom_reaper_wait);
+}
+#else
+static inline bool oom_reaping_in_progress(void)
+{
+	return false;
+}
+#endif
+
 static inline void set_current_oom_origin(void)
 {
 	current->signal->oom_flag_origin = true;
--- a/include/linux/swap.h
+++ b/include/linux/swap.h
@@ -137,6 +137,9 @@ union swap_header {
  */
 struct reclaim_state {
 	unsigned long reclaimed_slab;
+#ifdef CONFIG_LRU_GEN
+	struct mm_walk_args *mm_walk_args;
+#endif
 };
 
 #ifdef __KERNEL__
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1304,12 +1304,17 @@ void mem_cgroup_update_lru_size(struct l
 		*lru_size += nr_pages;
 
 	size = *lru_size;
+#ifdef CONFIG_LRU_GEN
+	/* unlikely but not a bug when reset_batch_size() is pending */
+	VM_WARN_ON(size + MAX_BATCH_SIZE < 0);
+#else
 	if (WARN_ONCE(size < 0,
 		"%s(%p, %d, %d): lru_size %ld\n",
 		__func__, lruvec, lru, nr_pages, size)) {
 		VM_BUG_ON(1);
 		*lru_size = 0;
 	}
+#endif
 
 	if (nr_pages > 0)
 		*lru_size += nr_pages;
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -508,8 +508,8 @@ bool process_shares_mm(struct task_struc
  * victim (if that is possible) to help the OOM killer to move on.
  */
 static struct task_struct *oom_reaper_th;
-static DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
-static struct task_struct *oom_reaper_list;
+DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
+struct task_struct *oom_reaper_list;
 static DEFINE_SPINLOCK(oom_reaper_lock);
 
 bool __oom_reap_task_mm(struct mm_struct *mm)
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -73,6 +73,7 @@
 #include <linux/page_idle.h>
 #include <linux/memremap.h>
 #include <linux/userfaultfd_k.h>
+#include <linux/mm_inline.h>
 
 #include <asm/tlbflush.h>
 
@@ -793,6 +794,13 @@ static bool page_referenced_one(struct p
 		}
 
 		if (pvmw.pte) {
+			/* the multigenerational lru exploits the spatial locality */
+			if (lru_gen_enabled() && pte_young(*pvmw.pte) &&
+			    !(vma->vm_flags & VM_SEQ_READ)) {
+				lru_gen_look_around(&pvmw);
+				referenced++;
+			}
+
 			if (ptep_clear_flush_young_notify(vma, address,
 						pvmw.pte)) {
 				/*
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -51,6 +51,8 @@
 #include <linux/dax.h>
 #include <linux/psi.h>
 #include <linux/memory.h>
+#include <linux/pagewalk.h>
+#include <linux/shmem_fs.h>
 
 #include <asm/tlbflush.h>
 #include <asm/div64.h>
@@ -2822,6 +2824,15 @@ static bool can_age_anon_pages(struct pg
  *                          shorthand helpers
  ******************************************************************************/
 
+#define DEFINE_MAX_SEQ(lruvec)						\
+	unsigned long max_seq = READ_ONCE((lruvec)->evictable.max_seq)
+/* declare a local min_seq[] snapshot of both lruvec->evictable.min_seq[] entries */
+#define DEFINE_MIN_SEQ(lruvec)						\
+	unsigned long min_seq[ANON_AND_FILE] = {			\
+		READ_ONCE((lruvec)->evictable.min_seq[0]),		\
+		READ_ONCE((lruvec)->evictable.min_seq[1]),		\
+	}
+
 #define for_each_gen_type_zone(gen, type, zone)				\
 	for ((gen) = 0; (gen) < MAX_NR_GENS; (gen)++)			\
 		for ((type) = 0; (type) < ANON_AND_FILE; (type)++)	\
@@ -2834,6 +2845,12 @@ static int page_lru_gen(struct page *pag
 	return ((flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1;
 }
 
+static int get_swappiness(struct mem_cgroup *memcg)
+{
+	return mem_cgroup_get_nr_swap_pages(memcg) >= MIN_BATCH_SIZE ?
+	       mem_cgroup_swappiness(memcg) : 0; /* 0 makes the aging skip anon and shmem VMAs */
+}
+
 static struct lruvec *get_lruvec(int nid, struct mem_cgroup *memcg)
 {
 	struct pglist_data *pgdat = NODE_DATA(nid);
@@ -3164,6 +3181,926 @@ done:
 }
 
 /******************************************************************************
+ *                          the aging
+ ******************************************************************************/
+
+static int page_update_gen(struct page *page, int gen)
+{
+	unsigned long old_flags, new_flags;
+	/* Set @page's gen to @gen; returns the old gen, or -1 if the page had none. */
+	VM_BUG_ON(gen >= MAX_NR_GENS);
+
+	do {
+		new_flags = old_flags = READ_ONCE(page->flags);
+		/* no generation recorded in the flags: only set PG_referenced */
+		if (!(new_flags & LRU_GEN_MASK)) {
+			new_flags |= BIT(PG_referenced);
+			continue;
+		}
+		/* the gen field stores gen + 1, so a zero field means "no generation" */
+		new_flags &= ~LRU_GEN_MASK;
+		new_flags |= (gen + 1UL) << LRU_GEN_PGOFF;
+	} while (new_flags != old_flags &&
+		 cmpxchg(&page->flags, old_flags, new_flags) != old_flags);
+
+	return ((old_flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1;
+}
+
+static void page_inc_gen(struct page *page, struct lruvec *lruvec, bool reclaiming)
+{
+	int old_gen, new_gen;
+	unsigned long old_flags, new_flags;
+	int type = page_is_file_lru(page);
+	int zone = page_zonenum(page);
+	struct lrugen *lrugen = &lruvec->evictable;
+	/* Move @page out of its type's min_seq generation into the following one. */
+	old_gen = lru_gen_from_seq(lrugen->min_seq[type]);
+
+	do {
+		new_flags = old_flags = READ_ONCE(page->flags);
+		VM_BUG_ON_PAGE(!(new_flags & LRU_GEN_MASK), page);
+
+		new_gen = ((new_flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1;
+		/* page_update_gen() already changed the gen: only move to the matching list */
+		if (new_gen >= 0 && new_gen != old_gen) {
+			list_move(&page->lru, &lrugen->lists[new_gen][type][zone]);
+			return;
+		}
+
+		new_gen = (old_gen + 1) % MAX_NR_GENS;
+
+		new_flags &= ~LRU_GEN_MASK;
+		new_flags |= (new_gen + 1UL) << LRU_GEN_PGOFF;
+		/* for end_page_writeback() */
+		if (reclaiming)
+			new_flags |= BIT(PG_reclaim);
+	} while (cmpxchg(&page->flags, old_flags, new_flags) != old_flags);
+	/* fix up the size counters, then requeue: head if reclaiming, tail otherwise */
+	lru_gen_update_size(page, lruvec, old_gen, new_gen);
+	if (reclaiming)
+		list_move(&page->lru, &lrugen->lists[new_gen][type][zone]);
+	else
+		list_move_tail(&page->lru, &lrugen->lists[new_gen][type][zone]);
+}
+
+static void update_batch_size(struct page *page, int old_gen, int new_gen,
+			      struct mm_walk_args *args)
+{
+	int type = page_is_file_lru(page);
+	int zone = page_zonenum(page);
+	int delta = thp_nr_pages(page);
+	/* Batch @page's move from @old_gen to @new_gen in @args; see reset_batch_size(). */
+	VM_BUG_ON(old_gen >= MAX_NR_GENS);
+	VM_BUG_ON(new_gen >= MAX_NR_GENS);
+
+	args->batch_size++;
+	/* the size delta is thp_nr_pages(page), while batch_size counts one per call */
+	args->nr_pages[old_gen][type][zone] -= delta;
+	args->nr_pages[new_gen][type][zone] += delta;
+}
+
+static void reset_batch_size(struct lruvec *lruvec, struct mm_walk_args *args)
+{
+	int gen, type, zone;
+	struct lrugen *lrugen = &lruvec->evictable;
+	/* Flush the deltas batched in @args into lrugen->sizes[] and the lru sizes. */
+	args->batch_size = 0;
+	/* walk_mm() calls this with lruvec->lru_lock held */
+	for_each_gen_type_zone(gen, type, zone) {
+		enum lru_list lru = type * LRU_FILE;
+		int delta = args->nr_pages[gen][type][zone];
+
+		if (!delta)
+			continue;
+
+		args->nr_pages[gen][type][zone] = 0;
+		WRITE_ONCE(lrugen->sizes[gen][type][zone],
+			   lrugen->sizes[gen][type][zone] + delta);
+		/* charge the active or the inactive lru list, depending on the generation */
+		if (lru_gen_is_active(lruvec, gen))
+			lru += LRU_ACTIVE;
+		update_lru_size(lruvec, lru, zone, delta);
+	}
+}
+
+static int should_skip_vma(unsigned long start, unsigned long end, struct mm_walk *walk)
+{
+	struct address_space *mapping;
+	struct vm_area_struct *vma = walk->vma;
+	struct mm_walk_args *args = walk->private;
+	/* ->test_walk callback: returns true (nonzero) for VMAs the aging should not scan. */
+	if (!vma_is_accessible(vma) || is_vm_hugetlb_page(vma) ||
+	    (vma->vm_flags & (VM_LOCKED | VM_SPECIAL | VM_SEQ_READ)))
+		return true;
+	/* scan anon VMAs only when swappiness is nonzero */
+	if (vma_is_anonymous(vma))
+		return !args->swappiness;
+
+	if (WARN_ON_ONCE(!vma->vm_file || !vma->vm_file->f_mapping))
+		return true;
+
+	mapping = vma->vm_file->f_mapping;
+	if (!mapping->a_ops->writepage)
+		return true;
+	/* shmem is skipped like anon when swappiness is zero */
+	return (shmem_mapping(mapping) && !args->swappiness) || mapping_unevictable(mapping);
+}
+
+/*
+ * Some userspace memory allocators create many single-page VMAs. So instead of
+ * returning to the PGD table for each of such VMAs, we finish at least an
+ * entire PMD table and therefore avoid many zigzags.
+ */
+static bool get_next_vma(struct mm_walk *walk, unsigned long mask, unsigned long size,
+			 unsigned long *start, unsigned long *end)
+{
+	unsigned long next = round_up(*end, size);
+	/* @mask selects the current table; @size is the span of one of its entries */
+	VM_BUG_ON(mask & size);
+	VM_BUG_ON(*start >= *end);
+	VM_BUG_ON((next & mask) != (*start & mask));
+	/* on success, [*start, *end) is the next range to scan within walk->vma */
+	while (walk->vma) {
+		if (next >= walk->vma->vm_end) {
+			walk->vma = walk->vma->vm_next;
+			continue;
+		}
+		/* the next VMA begins outside the current table: stop */
+		if ((next & mask) != (walk->vma->vm_start & mask))
+			return false;
+
+		if (should_skip_vma(walk->vma->vm_start, walk->vma->vm_end, walk)) {
+			walk->vma = walk->vma->vm_next;
+			continue;
+		}
+		/* clamp the new range to this VMA and to the end of the current table */
+		*start = max(next, walk->vma->vm_start);
+		next = (next | ~mask) + 1;
+		/* rounded-up boundaries can wrap to 0 */
+		*end = next && next < walk->vma->vm_end ? next : walk->vma->vm_end;
+
+		return true;
+	}
+
+	return false;
+}
+
+static bool walk_pte_range(pmd_t *pmd, unsigned long start, unsigned long end,
+			   struct mm_walk *walk)
+{
+	int i;
+	pte_t *pte;
+	spinlock_t *ptl;
+	unsigned long addr;
+	int worth = 0;
+	struct mm_walk_args *args = walk->private;
+	int old_gen, new_gen = lru_gen_from_seq(args->max_seq);
+	/* Scan a PTE table; returns true if >= MIN_BATCH_SIZE / 2 young PTEs were cleared. */
+	VM_BUG_ON(pmd_leaf(*pmd));
+
+	pte = pte_offset_map_lock(walk->mm, pmd, start & PMD_MASK, &ptl);
+	arch_enter_lazy_mmu_mode();
+restart:
+	for (i = pte_index(start), addr = start; addr != end; i++, addr += PAGE_SIZE) {
+		struct page *page;
+		unsigned long pfn = pte_pfn(pte[i]);
+
+		args->mm_stats[MM_LEAF_TOTAL]++;
+
+		if (!pte_present(pte[i]) || is_zero_pfn(pfn))
+			continue;
+
+		if (WARN_ON_ONCE(pte_devmap(pte[i]) || pte_special(pte[i])))
+			continue;
+
+		if (!pte_young(pte[i])) {
+			args->mm_stats[MM_LEAF_OLD]++;
+			continue;
+		}
+		/* only pages of the node and memcg being aged are of interest */
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < args->start_pfn || pfn >= args->end_pfn)
+			continue;
+
+		page = compound_head(pfn_to_page(pfn));
+		if (page_to_nid(page) != args->node_id)
+			continue;
+
+		if (page_memcg_rcu(page) != args->memcg)
+			continue;
+
+		VM_BUG_ON(addr < walk->vma->vm_start || addr >= walk->vma->vm_end);
+		if (!ptep_test_and_clear_young(walk->vma, addr, pte + i))
+			continue;
+
+		args->mm_stats[MM_LEAF_YOUNG]++;
+		/* propagate the PTE dirty bit, except for anon pages not in the swap cache */
+		if (pte_dirty(pte[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page)))
+			set_page_dirty(page);
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			update_batch_size(page, old_gen, new_gen, args);
+
+		worth++;
+	}
+	/* more VMAs may map the rest of this PTE table: go on without dropping the lock */
+	if (i < PTRS_PER_PTE && get_next_vma(walk, PMD_MASK, PAGE_SIZE, &start, &end))
+		goto restart;
+
+	arch_leave_lazy_mmu_mode();
+	pte_unmap_unlock(pte, ptl);
+
+	return worth >= MIN_BATCH_SIZE / 2;
+}
+
+/*
+ * We scan PMD entries in two passes. The first pass reaches to PTE tables and
+ * doesn't take the PMD lock. The second pass clears the accessed bit on PMD
+ * entries and needs to take the PMD lock.
+ */
+#if defined(CONFIG_TRANSPARENT_HUGEPAGE) || defined(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG)
+static void walk_pmd_range_locked(pud_t *pud, unsigned long start, int offset,
+				  struct vm_area_struct *vma, struct mm_walk *walk)
+{
+	int i;
+	pmd_t *pmd;
+	spinlock_t *ptl;
+	struct mm_walk_args *args = walk->private;
+	int old_gen, new_gen = lru_gen_from_seq(args->max_seq);
+	/* Second pass: handle the PMD entries marked in args->bitmap, relative to @offset. */
+	VM_BUG_ON(pud_leaf(*pud));
+
+	start = (start & PUD_MASK) + offset * PMD_SIZE;
+	pmd = pmd_offset(pud, start);
+	ptl = pmd_lock(walk->mm, pmd);
+	arch_enter_lazy_mmu_mode();
+
+	for_each_set_bit(i, args->bitmap, MIN_BATCH_SIZE) {
+		struct page *page;
+		unsigned long pfn = pmd_pfn(pmd[i]);
+		unsigned long addr = start + i * PMD_SIZE;
+
+		if (!pmd_present(pmd[i]) || is_huge_zero_pmd(pmd[i]))
+			continue;
+
+		if (WARN_ON_ONCE(pmd_devmap(pmd[i])))
+			continue;
+		/* non-leaf entry: only clear its accessed bit, if the arch maintains one */
+		if (!pmd_trans_huge(pmd[i])) {
+			if (IS_ENABLED(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG))
+				pmdp_test_and_clear_young(vma, addr, pmd + i);
+			continue;
+		}
+
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < args->start_pfn || pfn >= args->end_pfn)
+			continue;
+
+		page = pfn_to_page(pfn);
+		VM_BUG_ON_PAGE(PageTail(page), page);
+		if (page_to_nid(page) != args->node_id)
+			continue;
+
+		if (page_memcg_rcu(page) != args->memcg)
+			continue;
+
+		VM_BUG_ON(addr < vma->vm_start || addr >= vma->vm_end);
+		if (!pmdp_test_and_clear_young(vma, addr, pmd + i))
+			continue;
+
+		args->mm_stats[MM_LEAF_YOUNG]++;
+
+		if (pmd_dirty(pmd[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page)))
+			set_page_dirty(page);
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			update_batch_size(page, old_gen, new_gen, args);
+	}
+
+	arch_leave_lazy_mmu_mode();
+	spin_unlock(ptl);
+	/* the bitmap has been consumed; the next batch starts from an empty one */
+	bitmap_zero(args->bitmap, MIN_BATCH_SIZE);
+}
+#else
+static void walk_pmd_range_locked(pud_t *pud, unsigned long start, int offset,
+				  struct vm_area_struct *vma, struct mm_walk *walk)
+{
+}
+#endif
+
+static void walk_pmd_range(pud_t *pud, unsigned long start, unsigned long end,
+			   struct mm_walk *walk)
+{
+	int i;
+	pmd_t *pmd;
+	unsigned long next;
+	unsigned long addr;
+	struct vm_area_struct *vma;
+	int offset = -1;
+	bool reset = false;
+	struct mm_walk_args *args = walk->private;
+	struct lruvec *lruvec = get_lruvec(args->node_id, args->memcg);
+	/* Scan a PMD table; entries needing the PMD lock are batched in args->bitmap. */
+	VM_BUG_ON(pud_leaf(*pud));
+
+	pmd = pmd_offset(pud, start & PUD_MASK);
+restart:
+	vma = walk->vma;
+	for (i = pmd_index(start), addr = start; addr != end; i++, addr = next) {
+		pmd_t val = pmd_read_atomic(pmd + i);
+
+		/* for pmd_read_atomic() */
+		barrier();
+
+		next = pmd_addr_end(addr, end);
+
+		if (!pmd_present(val)) {
+			args->mm_stats[MM_LEAF_TOTAL]++;
+			continue;
+		}
+
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+		if (pmd_trans_huge(val)) {
+			unsigned long pfn = pmd_pfn(val);
+
+			args->mm_stats[MM_LEAF_TOTAL]++;
+
+			if (is_huge_zero_pmd(val))
+				continue;
+
+			if (!pmd_young(val)) {
+				args->mm_stats[MM_LEAF_OLD]++;
+				continue;
+			}
+
+			if (pfn < args->start_pfn || pfn >= args->end_pfn)
+				continue;
+			/* flush the batch once it spans MIN_BATCH_SIZE entries */
+			if (offset < 0)
+				offset = i;
+			else if (i - offset >= MIN_BATCH_SIZE) {
+				walk_pmd_range_locked(pud, start, offset, vma, walk);
+				offset = i;
+			}
+			__set_bit(i - offset, args->bitmap);
+			reset = true;
+			continue;
+		}
+#endif
+		args->mm_stats[MM_NONLEAF_TOTAL]++;
+
+#ifdef CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG
+		if (!pmd_young(val))
+			continue;
+
+		if (offset < 0)
+			offset = i;
+		else if (i - offset >= MIN_BATCH_SIZE) {
+			walk_pmd_range_locked(pud, start, offset, vma, walk);
+			offset = i;
+			reset = false;
+		}
+		__set_bit(i - offset, args->bitmap);
+#endif
+		if (args->use_filter && !test_bloom_filter(lruvec, args->max_seq, pmd + i))
+			continue;
+		/* passed the bloom filter, or the filter is not in use: scan the PTE table */
+		args->mm_stats[MM_NONLEAF_PREV]++;
+
+		if (!walk_pte_range(&val, addr, next, walk))
+			continue;
+
+		args->mm_stats[MM_NONLEAF_CUR]++;
+		/* enough young PTEs here: record this PMD entry in the max_seq + 1 filter */
+		set_bloom_filter(lruvec, args->max_seq + 1, pmd + i);
+	}
+
+	if (reset) {
+		walk_pmd_range_locked(pud, start, offset, vma, walk);
+		offset = -1;
+		reset = false;
+	}
+	/* continue with the next VMA mapped by this PMD table, if any */
+	if (i < PTRS_PER_PMD && get_next_vma(walk, PUD_MASK, PMD_SIZE, &start, &end))
+		goto restart;
+
+	if (offset >= 0)
+		walk_pmd_range_locked(pud, start, offset, vma, walk);
+}
+
+static int walk_pud_range(p4d_t *p4d, unsigned long start, unsigned long end,
+			  struct mm_walk *walk)
+{
+	int i;
+	pud_t *pud;
+	unsigned long addr;
+	unsigned long next;
+	struct mm_walk_args *args = walk->private;
+	/* ->p4d_entry callback: scans a PUD table and records where to resume in next_addr. */
+	VM_BUG_ON(p4d_leaf(*p4d));
+
+	pud = pud_offset(p4d, start & P4D_MASK);
+restart:
+	for (i = pud_index(start), addr = start; addr != end; i++, addr = next) {
+		pud_t val = READ_ONCE(pud[i]);
+
+		next = pud_addr_end(addr, end);
+
+		if (!pud_present(val) || WARN_ON_ONCE(pud_leaf(val)))
+			continue;
+
+		walk_pmd_range(&val, addr, next, walk);
+		/* the batch is full: stop here so walk_mm() can flush it and resume */
+		if (args->batch_size >= MAX_BATCH_SIZE) {
+			end = (addr | ~PUD_MASK) + 1;
+			goto done;
+		}
+	}
+
+	if (i < PTRS_PER_PUD && get_next_vma(walk, P4D_MASK, PUD_SIZE, &start, &end))
+		goto restart;
+
+	end = round_up(end, P4D_SIZE);
+done:
+	/* rounded-up boundaries can wrap to 0 */
+	args->next_addr = end && walk->vma ? max(end, walk->vma->vm_start) : 0;
+	/* always -EAGAIN; walk_mm() stops looping once next_addr is 0 */
+	return -EAGAIN;
+}
+
+static void walk_mm(struct lruvec *lruvec, struct mm_struct *mm, struct mm_walk_args *args)
+{
+	static const struct mm_walk_ops mm_walk_ops = {
+		.test_walk = should_skip_vma,
+		.p4d_entry = walk_pud_range,
+	};
+
+	int err;
+	/* Walk @mm in pieces, flushing the batched size updates after each piece. */
+	args->next_addr = FIRST_USER_ADDRESS;
+
+	do {
+		unsigned long start = args->next_addr;
+		unsigned long end = mm->highest_vm_end;
+		/* stays -EBUSY, ending the loop, if moving_account is set or the trylock fails */
+		err = -EBUSY;
+
+		rcu_read_lock();
+#ifdef CONFIG_MEMCG
+		if (args->memcg && atomic_read(&args->memcg->moving_account))
+			goto contended;
+#endif
+		if (!mmap_read_trylock(mm))
+			goto contended;
+
+		err = walk_page_range(mm, start, end, &mm_walk_ops, args);
+
+		mmap_read_unlock(mm);
+		/* apply the size deltas accumulated during this piece of the walk */
+		if (args->batch_size) {
+			spin_lock_irq(&lruvec->lru_lock);
+			reset_batch_size(lruvec, args);
+			spin_unlock_irq(&lruvec->lru_lock);
+		}
+contended:
+		rcu_read_unlock();
+
+		cond_resched();
+	} while (err == -EAGAIN && args->next_addr && !mm_is_oom_victim(mm));
+}
+
+static struct mm_walk_args *alloc_mm_walk_args(void)
+{
+	if (!current->reclaim_state || !current->reclaim_state->mm_walk_args)
+		return kvzalloc(sizeof(struct mm_walk_args), GFP_KERNEL);
+	/* otherwise reuse the args already installed in current->reclaim_state */
+	return current->reclaim_state->mm_walk_args;
+}
+
+static void free_mm_walk_args(struct mm_walk_args *args)
+{
+	if (!current->reclaim_state || !current->reclaim_state->mm_walk_args)
+		kvfree(args); /* same test as in alloc_mm_walk_args(): args was kvzalloc()'ed */
+}
+
+static bool inc_min_seq(struct lruvec *lruvec, int type)
+{
+	int gen, zone;
+	int remaining = MAX_BATCH_SIZE;
+	struct lrugen *lrugen = &lruvec->evictable;
+	/* Empty the min_seq generation of @type, at most MAX_BATCH_SIZE pages per call. */
+	VM_BUG_ON(!seq_is_valid(lruvec));
+	/* nothing to do unless all MAX_NR_GENS generations are in use */
+	if (get_nr_gens(lruvec, type) != MAX_NR_GENS)
+		return true;
+
+	gen = lru_gen_from_seq(lrugen->min_seq[type]);
+
+	for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+		struct list_head *head = &lrugen->lists[gen][type][zone];
+
+		while (!list_empty(head)) {
+			struct page *page = lru_to_page(head);
+
+			VM_BUG_ON_PAGE(PageTail(page), page);
+			VM_BUG_ON_PAGE(PageUnevictable(page), page);
+			VM_BUG_ON_PAGE(PageActive(page), page);
+			VM_BUG_ON_PAGE(page_is_file_lru(page) != type, page);
+			VM_BUG_ON_PAGE(page_zonenum(page) != zone, page);
+
+			prefetchw_prev_lru_page(page, head, flags);
+
+			page_inc_gen(page, lruvec, false);
+			/* budget used up: inc_max_seq() retries after rescheduling */
+			if (!--remaining)
+				return false;
+		}
+	}
+
+	WRITE_ONCE(lrugen->min_seq[type], lrugen->min_seq[type] + 1);
+
+	return true;
+}
+
+static bool try_to_inc_min_seq(struct lruvec *lruvec, int swappiness)
+{
+	int gen, type, zone;
+	bool success = false;
+	struct lrugen *lrugen = &lruvec->evictable;
+	DEFINE_MIN_SEQ(lruvec);
+	/* Advance min_seq[] past empty generations; returns true if any min_seq changed. */
+	VM_BUG_ON(!seq_is_valid(lruvec));
+
+	for (type = 0; type < ANON_AND_FILE; type++) {
+		while (lrugen->max_seq - min_seq[type] >= MIN_NR_GENS) {
+			gen = lru_gen_from_seq(min_seq[type]);
+
+			for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+				if (!list_empty(&lrugen->lists[gen][type][zone]))
+					goto next;
+			}
+
+			min_seq[type]++;
+		}
+next:
+		;
+	}
+	/* min_seq[0] (anon) is never allowed to exceed min_seq[1] (file) */
+	min_seq[0] = min(min_seq[0], min_seq[1]);
+	if (swappiness)
+		min_seq[1] = max(min_seq[0], lrugen->min_seq[1]);
+	/* publish only the values that actually changed */
+	for (type = 0; type < ANON_AND_FILE; type++) {
+		if (min_seq[type] == lrugen->min_seq[type])
+			continue;
+
+		WRITE_ONCE(lrugen->min_seq[type], min_seq[type]);
+		success = true;
+	}
+
+	return success;
+}
+
+static void inc_max_seq(struct lruvec *lruvec, unsigned long max_seq)
+{
+	int gen, type, zone;
+	struct lrugen *lrugen = &lruvec->evictable;
+	/* Increment lrugen->max_seq, unless it has already moved past @max_seq. */
+	spin_lock_irq(&lruvec->lru_lock);
+
+	VM_BUG_ON(!seq_is_valid(lruvec));
+
+	if (max_seq != lrugen->max_seq)
+		goto unlock;
+	/* make room: advance min_seq, moving leftover pages to the next generation if needed */
+	if (!try_to_inc_min_seq(lruvec, true)) {
+		for (type = ANON_AND_FILE - 1; type >= 0; type--) {
+			while (!inc_min_seq(lruvec, type)) {
+				spin_unlock_irq(&lruvec->lru_lock);
+				cond_resched();
+				spin_lock_irq(&lruvec->lru_lock);
+			}
+		}
+	}
+	/* the generation at max_seq - 1 moves from the active to the inactive lru sizes */
+	gen = lru_gen_from_seq(lrugen->max_seq - 1);
+	for (type = 0; type < ANON_AND_FILE; type++) {
+		for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+			enum lru_list lru = type * LRU_FILE;
+			long delta = lrugen->sizes[gen][type][zone];
+
+			if (!delta)
+				continue;
+
+			WARN_ON_ONCE(delta != (int)delta);
+
+			update_lru_size(lruvec, lru, zone, delta);
+			update_lru_size(lruvec, lru + LRU_ACTIVE, zone, -delta);
+		}
+	}
+	/* the generation at max_seq + 1 is moved the other way, inactive to active */
+	gen = lru_gen_from_seq(lrugen->max_seq + 1);
+	for (type = 0; type < ANON_AND_FILE; type++) {
+		for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+			enum lru_list lru = type * LRU_FILE;
+			long delta = lrugen->sizes[gen][type][zone];
+
+			if (!delta)
+				continue;
+
+			WARN_ON_ONCE(delta != (int)delta);
+
+			update_lru_size(lruvec, lru, zone, -delta);
+			update_lru_size(lruvec, lru + LRU_ACTIVE, zone, delta);
+		}
+	}
+	/* timestamp the birth of the new generation; age_lruvec() reads timestamps[] */
+	WRITE_ONCE(lrugen->timestamps[gen], jiffies);
+	/* make sure all preceding modifications appear first */
+	smp_store_release(&lrugen->max_seq, lrugen->max_seq + 1);
+unlock:
+	spin_unlock_irq(&lruvec->lru_lock);
+}
+
+/* Main function used by the foreground, the background and the user-triggered aging. */
+static bool try_to_inc_max_seq(struct lruvec *lruvec, struct scan_control *sc, int swappiness,
+			       unsigned long max_seq, bool use_filter)
+{
+	bool last;
+	struct mm_walk_args *args;
+	struct mm_struct *mm = NULL;
+	struct lrugen *lrugen = &lruvec->evictable;
+	struct mem_cgroup *memcg = lruvec_memcg(lruvec);
+	struct pglist_data *pgdat = lruvec_pgdat(lruvec);
+	int nid = pgdat->node_id;
+	/* Returns false if the args allocation fails or max_seq has not advanced. */
+	VM_BUG_ON(max_seq > READ_ONCE(lrugen->max_seq));
+
+	/*
+	 * If we are not from run_aging() and clearing the accessed bit may
+	 * trigger page faults, then don't proceed to clearing all accessed
+	 * PTEs. Instead, fallback to lru_gen_look_around(), which only clears a
+	 * handful of accessed PTEs. This is less efficient but causes fewer
+	 * page faults on CPUs that don't have the capability.
+	 */
+	if ((current->flags & PF_MEMALLOC) && !arch_has_hw_pte_young(false)) {
+		inc_max_seq(lruvec, max_seq);
+		return true;
+	}
+
+	args = alloc_mm_walk_args();
+	if (!args)
+		return false;
+
+	args->memcg = memcg;
+	args->max_seq = max_seq;
+	args->start_pfn = pgdat->node_start_pfn;
+	args->end_pfn = pgdat_end_pfn(pgdat);
+	args->node_id = nid;
+	args->swappiness = swappiness;
+	args->use_filter = use_filter;
+	/* walk every mm handed out by get_next_mm() for this lruvec */
+	do {
+		last = get_next_mm(lruvec, args, &mm);
+		if (mm)
+			walk_mm(lruvec, mm, args);
+
+		cond_resched();
+	} while (mm);
+
+	free_mm_walk_args(args);
+	/* NOTE(review): !last presumably means other walkers are not done yet -- confirm */
+	if (!last) {
+		/* don't wait unless we may have trouble reclaiming */
+		if (!current_is_kswapd() && sc->priority < DEF_PRIORITY - 2)
+			wait_event_killable(lruvec->mm_walk.wait,
+					    max_seq < READ_ONCE(lrugen->max_seq));
+
+		return max_seq < READ_ONCE(lrugen->max_seq);
+	}
+
+	VM_BUG_ON(max_seq != READ_ONCE(lrugen->max_seq));
+
+	inc_max_seq(lruvec, max_seq);
+	/* either we see any waiters or they will see updated max_seq */
+	if (wq_has_sleeper(&lruvec->mm_walk.wait))
+		wake_up_all(&lruvec->mm_walk.wait);
+
+	wakeup_flusher_threads(WB_REASON_VMSCAN);
+
+	return true;
+}
+
+static long get_nr_evictable(struct lruvec *lruvec, struct scan_control *sc, int swappiness,
+			     unsigned long max_seq, unsigned long *min_seq, bool *low)
+{
+	int gen, type, zone;
+	long max = 0;
+	long min = 0;
+	struct lrugen *lrugen = &lruvec->evictable;
+	/* Count pages in zones <= sc->reclaim_idx; anon is included only with swappiness. */
+	for (type = !swappiness; type < ANON_AND_FILE; type++) {
+		unsigned long seq;
+
+		for (seq = min_seq[type]; seq <= max_seq; seq++) {
+			long size = 0;
+
+			gen = lru_gen_from_seq(seq);
+
+			for (zone = 0; zone <= sc->reclaim_idx; zone++)
+				size += READ_ONCE(lrugen->sizes[gen][type][zone]);
+			/* min: file pages at least MIN_NR_GENS generations behind max_seq */
+			max += size;
+			if (type && max_seq - seq >= MIN_NR_GENS)
+				min += size;
+		}
+	}
+	/* low: at most MIN_NR_GENS + 1 file generations and under MIN_BATCH_SIZE old file pages */
+	*low = max_seq - min_seq[1] <= MIN_NR_GENS && min < MIN_BATCH_SIZE;
+
+	return max > 0 ? max : 0;
+}
+
+static bool age_lruvec(struct lruvec *lruvec, struct scan_control *sc,
+		       unsigned long min_ttl)
+{
+	bool low;
+	long nr_to_scan;
+	struct mem_cgroup *memcg = lruvec_memcg(lruvec);
+	int swappiness = get_swappiness(memcg);
+	DEFINE_MAX_SEQ(lruvec);
+	DEFINE_MIN_SEQ(lruvec);
+	/* Returns false if @lruvec is below min, younger than @min_ttl or has nothing evictable. */
+	if (mem_cgroup_below_min(memcg))
+		return false;
+
+	if (min_ttl) {
+		int gen = lru_gen_from_seq(min_seq[1]);
+		unsigned long birth = READ_ONCE(lruvec->evictable.timestamps[gen]);
+		/* the file generation at min_seq[1] was created less than min_ttl ago */
+		if (time_is_after_jiffies(birth + min_ttl))
+			return false;
+	}
+
+	nr_to_scan = get_nr_evictable(lruvec, sc, swappiness, max_seq, min_seq, &low);
+	if (!nr_to_scan)
+		return false;
+
+	nr_to_scan >>= sc->priority;
+
+	if (!mem_cgroup_online(memcg))
+		nr_to_scan++;
+	/* run the aging only when get_nr_evictable() reported *low */
+	if (nr_to_scan && low && (!mem_cgroup_below_low(memcg) || sc->memcg_low_reclaim))
+		try_to_inc_max_seq(lruvec, sc, swappiness, max_seq, true);
+
+	return true;
+}
+
+/* Protect the working set accessed within the last N milliseconds. */
+static unsigned long lru_gen_min_ttl __read_mostly;
+
+static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
+{
+	struct mem_cgroup *memcg;
+	bool success = false;
+	unsigned long min_ttl = READ_ONCE(lru_gen_min_ttl);
+	/* kswapd only: age each memcg's lruvec on @pgdat; OOM kill if none could be aged. */
+	VM_BUG_ON(!current_is_kswapd());
+	/* the first call only sets force_deactivate; the aging runs on a later call */
+	if (!sc->force_deactivate) {
+		sc->force_deactivate = 1;
+		return;
+	}
+	/* let alloc_mm_walk_args() and lru_gen_look_around() use the per-node args */
+	current->reclaim_state->mm_walk_args = &pgdat->mm_walk_args;
+
+	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	do {
+		struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
+
+		if (age_lruvec(lruvec, sc, min_ttl))
+			success = true;
+
+		cond_resched();
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+	/* age_lruvec() returned false for every memcg: try an OOM kill if oom_lock is free */
+	if (!success && mutex_trylock(&oom_lock)) {
+		struct oom_control oc = {
+			.gfp_mask = sc->gfp_mask,
+			.order = sc->order,
+		};
+
+		/* to avoid overkilling */
+		if (!oom_reaping_in_progress())
+			out_of_memory(&oc);
+
+		mutex_unlock(&oom_lock);
+	}
+
+	current->reclaim_state->mm_walk_args = NULL;
+}
+
+/* Scan the vicinity of an accessed PTE when shrink_page_list() uses the rmap. */
+void lru_gen_look_around(struct page_vma_mapped_walk *pvmw)
+{
+	int i;
+	pte_t *pte;
+	struct page *page;
+	int old_gen, new_gen;
+	unsigned long start;
+	unsigned long end;
+	unsigned long addr;
+	struct mm_walk_args *args;
+	int worth = 0;
+	struct mem_cgroup *memcg = page_memcg(pvmw->page);
+	struct pglist_data *pgdat = page_pgdat(pvmw->page);
+	struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
+	DEFINE_MAX_SEQ(lruvec);
+
+	lockdep_assert_held(pvmw->ptl);
+	VM_BUG_ON_PAGE(PageLRU(pvmw->page), pvmw->page);
+	/* nothing to do without the mm_walk_args set up in current->reclaim_state */
+	args = current->reclaim_state ? current->reclaim_state->mm_walk_args : NULL;
+	if (!args)
+		return;
+	/* scan within the PMD range and the VMA that contain pvmw->address */
+	start = max(pvmw->address & PMD_MASK, pvmw->vma->vm_start);
+	end = min(pvmw->address | ~PMD_MASK, pvmw->vma->vm_end - 1) + 1;
+	/* cap the window at MIN_BATCH_SIZE pages, centered on the address when possible */
+	if (end - start > MIN_BATCH_SIZE * PAGE_SIZE) {
+		if (pvmw->address - start < MIN_BATCH_SIZE * PAGE_SIZE / 2)
+			end = start + MIN_BATCH_SIZE * PAGE_SIZE;
+		else if (end - pvmw->address < MIN_BATCH_SIZE * PAGE_SIZE / 2)
+			start = end - MIN_BATCH_SIZE * PAGE_SIZE;
+		else {
+			start = pvmw->address - MIN_BATCH_SIZE * PAGE_SIZE / 2;
+			end = pvmw->address + MIN_BATCH_SIZE * PAGE_SIZE / 2;
+		}
+	}
+	/* pvmw->pte maps pvmw->address; step back to the PTE that maps start */
+	pte = pvmw->pte - (pvmw->address - start) / PAGE_SIZE;
+	new_gen = lru_gen_from_seq(max_seq);
+
+	lock_page_memcg(pvmw->page);
+	arch_enter_lazy_mmu_mode();
+
+	for (i = 0, addr = start; addr != end; i++, addr += PAGE_SIZE) {
+		unsigned long pfn = pte_pfn(pte[i]);
+
+		if (!pte_present(pte[i]) || is_zero_pfn(pfn))
+			continue;
+
+		if (WARN_ON_ONCE(pte_devmap(pte[i]) || pte_special(pte[i])))
+			continue;
+
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
+			continue;
+
+		worth++;
+
+		if (!pte_young(pte[i]))
+			continue;
+
+		page = compound_head(pfn_to_page(pfn));
+		if (page_to_nid(page) != pgdat->node_id)
+			continue;
+
+		if (page_memcg_rcu(page) != memcg)
+			continue;
+
+		VM_BUG_ON(addr < pvmw->vma->vm_start || addr >= pvmw->vma->vm_end);
+		if (!ptep_test_and_clear_young(pvmw->vma, addr, pte + i))
+			continue;
+		/* set_page_dirty() is deferred until after the loop: only record the index */
+		if (pte_dirty(pte[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page)))
+			__set_bit(i, args->bitmap);
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			update_batch_size(page, old_gen, new_gen, args);
+	}
+
+	arch_leave_lazy_mmu_mode();
+	unlock_page_memcg(pvmw->page);
+	/* enough present PTEs in the window: add this PMD entry to the bloom filter */
+	if (worth >= MIN_BATCH_SIZE / 2)
+		set_bloom_filter(lruvec, max_seq, pvmw->pmd);
+	/* now dirty the pages whose indices were recorded in the bitmap */
+	for_each_set_bit(i, args->bitmap, MIN_BATCH_SIZE)
+		set_page_dirty(pte_page(pte[i]));
+
+	bitmap_zero(args->bitmap, MIN_BATCH_SIZE);
+}
+
+/******************************************************************************
  *                          state change
  ******************************************************************************/
 
@@ -3412,6 +4349,12 @@ static int __init init_lru_gen(void)
 };
 late_initcall(init_lru_gen);
 
+#else
+
+static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
+{
+}
+
 #endif /* CONFIG_LRU_GEN */
 
 static void shrink_lruvec(struct lruvec *lruvec, struct scan_control *sc)
@@ -4266,6 +5209,11 @@ static void age_active_anon(struct pglis
 	struct mem_cgroup *memcg;
 	struct lruvec *lruvec;
 
+	if (lru_gen_enabled()) {
+		lru_gen_age_node(pgdat, sc);
+		return;
+	}
+
 	if (!can_age_anon_pages(pgdat, sc))
 		return;