Source Code
#include <bits/stdc++.h>

using namespace std ;

// Array capacity for nodes; nodes are numbered from 1 (main reads parents of
// nodes 2..n and dfs starts at node 1).
constexpr int MAX = 200010 ;

// Binary-lifting tables, indexed [node][level]. All are zero-initialised
// (static storage), so node 0 / the root's parent reads as 0.
//   P[v][j]               - ancestor of v reached by 2^j parent steps
//   Max[v][j], Min[v][j]  - max / min of the level-0 values that dfs() stores,
//                           taken over the 2^j nodes starting at v going upward
int P[MAX][20] , Max[MAX][20] , Min[MAX][20] ;

// Depth of every node; the root (node 1) is never assigned, so it stays 0.
int dep[MAX] ;

// adj[v] lists the children of v (edges are only stored parent -> child).
vector<vector<int>> adj(MAX) ;

// Node count and query count, read in main().
int n , q ;

// Fill the binary-lifting tables P / Max / Min for every node in the subtree of
// `node`, assigning dep[] to all of its descendants on the way.
//
// Preconditions set up by the caller: P[v][0] holds the parent of each node v,
// adj[v] lists v's children, and dep[node] is already correct.
//
// Level-0 values: Max[v][0] = Min[v][0] = dep[v] when v has more than one
// child; otherwise Max[v][0] = -1 and Min[v][0] = 1e9 (neutral for max / min).
// Level j combines level j-1 of v with level j-1 of its 2^(j-1)-th ancestor,
// so entry j covers the 2^j nodes starting at v and going up. Above the root
// the chain runs into node 0, whose entries are zero-initialised.
//
// Uses an explicit stack instead of recursion. On a chain-shaped tree the
// recursive form nests about n calls deep, which can overflow the call stack.
// Every entry depends only on ancestors' entries, and a child is pushed only
// after its parent has been filled. So this pre-order yields exactly the same
// tables.
void dfs(int node)
{
	vector<int> pending{node} ;
	while(!pending.empty())
	{
		int v = pending.back() ;
		pending.pop_back() ;
		Max[v][0] = -1 , Min[v][0] = 1e9 ;
		if(adj[v].size() > 1)
			Max[v][0] = dep[v] , Min[v][0] = dep[v] ;
		for(int j = 1 ; j < 18 ; ++j)
		{
			int half = P[v][j-1] ;   // 2^(j-1)-th ancestor of v
			P[v][j] = P[half][j-1] ;
			Max[v][j] = max(Max[v][j-1] , Max[half][j-1]) ;
			Min[v][j] = min(Min[v][j-1] , Min[half][j-1]) ;
		}
		for(int child : adj[v])
		{
			dep[child] = dep[v] + 1 ;
			pending.push_back(child) ;
		}
	}
}

// Lowest common ancestor of x and y using the P[][] jump table.
// First lifts the deeper endpoint to the other's depth. If they still differ,
// both climb in lock-step while their 2^j-th ancestors differ. Their common
// parent is then the answer.
int LCA(int x , int y)
{
	if(dep[y] > dep[x])
		swap(x , y) ;   // x is now the deeper (or equally deep) endpoint
	for(int j = 17 ; j >= 0 ; --j)
		if(dep[y] + (1 << j) <= dep[x])
			x = P[x][j] ;
	if(x != y)
	{
		for(int j = 17 ; j >= 0 ; --j)
			if(P[y][j] != P[x][j])
			{
				x = P[x][j] ;
				y = P[y][j] ;
			}
		x = P[x][0] ;
	}
	return x ;
}

// Largest Max[][0] value over the nodes from x upward, stopping one level
// below depth dep[y] (the node at dep[y] itself is not included). Callers pass
// y = LCA, so presumably y is an ancestor of x; TODO confirm for other uses.
// With the values dfs() stores, this is the greatest depth of a node with more
// than one child on that path. It is -1 when there is none or dep[x] <= dep[y].
int FindMax(int x , int y)
{
	int best = -1 ;
	for(int j = 17 ; j >= 0 ; --j)
	{
		if(dep[y] + (1 << j) > dep[x])
			continue ;   // this jump would reach depth dep[y] or above
		best = max(best , Max[x][j]) ;
		x = P[x][j] ;
	}
	return best ;
}

// Smallest Min[][0] value over the nodes from x upward, stopping one level
// below depth dep[y] (the node at dep[y] itself is not included). Callers pass
// y = LCA, so presumably y is an ancestor of x; TODO confirm for other uses.
// With the values dfs() stores, this is the shallowest depth of a node with
// more than one child on that path. Returns -1 instead of the 1e9 sentinel
// when no such node was visited (including when dep[x] <= dep[y]).
int FindMin(int x , int y)
{
	int lowest = 1e9 ;
	for(int j = 17 ; j >= 0 ; --j)
	{
		if(dep[y] + (1 << j) > dep[x])
			continue ;   // this jump would reach depth dep[y] or above
		lowest = min(lowest , Min[x][j]) ;
		x = P[x][j] ;
	}
	return lowest == 1e9 ? -1 : lowest ;
}

// Reads a tree of n nodes rooted at 1 (the parent of each node 2..n) and
// answers q queries (a, b), printing "yes" or "no" for each.
//
// A "branch node" below means a node with more than one child. These are the
// nodes whose depth dfs() records in Max[][0] / Min[][0].
//
// NOTE(review): this submission was judged "Wrong Answer". The case analysis
// needs re-checking against the problem statement (not available here): the
// early "yes" for a == b or non-leaf b, the junction test on lca, and the
// strict dist1 > dist2 tie-break. This edit keeps the output identical.
int main()
{
	ios_base::sync_with_stdio(0) ;
	cin.tie(0) ;
	cin>>n>>q ;
	for(int i = 2 ; i <= n ; ++i)
	{
		cin>>P[i][0] ;
		adj[P[i][0]].push_back(i) ;
	}
	dfs(1) ;   // dep[1] stays 0 (zero-initialised global)
	while(q--)
	{
		int a , b ;
		cin>>a>>b ;
		// Same node, or b has at least one child: answered "yes" directly.
		if(a == b || !adj[b].empty())
		{
			cout<<"yes\n" ;
			continue ;
		}
		int lca = LCA(a , b) ;
		// Depth of the deepest branch node on the path b -> lca (lca excluded).
		int x = FindMax(b , lca) , dist1 = 0 , dist2 = 0 ;
		if(x != -1)
			// Distances from a and from b to that branch node.
			dist1 = dep[a] + x - 2 * dep[lca] , dist2 = dep[b] - x ;
		else if(adj[lca].size() > 2 || lca != 1 || (adj[lca].size() > 1 && (lca == a || lca == b)))
			// lca itself is used as the meeting point.
			dist1 = dep[a] - dep[lca] , dist2 = dep[b] - dep[lca] ;
		else
		{
			// Shallowest branch node on the path a -> lca (lca excluded).
			// Computed once; the original evaluated this call three times.
			int m = FindMin(a , lca) ;
			if(m != -1)
				dist1 = dep[a] - m , dist2 = dep[b] + m - 2 * dep[lca] ;
		}
		cout<<(dist1 > dist2 ? "yes\n" : "no\n") ;
	}
	return 0 ;
}
Copy
Escape from TarkZoo Bakry_
GNU G++17
58 ms
29.3 MB
Wrong Answer