P5043 【模板】树同构([BJOI2015]树的同构)

P5043 【模板】树同构([BJOI2015]树的同构)

P5043 【模板】树同构([BJOI2015]树的同构)
思路:树哈希。先找出树的重心(一棵树的重心最多有两个),然后以重心为根求出树的哈希值,放进 map 里;若有两个重心,则分别以两个重心为根求哈希值,把排序后的(较小值, 较大值)二元组放进另一个 map 里。
代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros.
// `y1` is renamed to `y11`, presumably so a user variable named y1 cannot
// clash with the Bessel function y1 declared by <cmath>/<math.h>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
// Segment-tree argument packs for the left / right child of node rt over [l, r].
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "<expression text> = <value>" and a newline to stderr.
// The argument is parenthesized so expressions such as `a & b` stream their
// value instead of mis-parsing against operator<<. The expansion already ends
// with ';', so it can be written as a statement without a trailing semicolon.
#define debug(x) cerr << #x << " = " << (x) << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head
 
// Multiplier for the polynomial that combines sorted child hashes in dfs().
const ULL B = 133;
// Number of nodes in the tree currently being processed.
int n;
// Number of trees in the input.
int m;
// Scratch value: parent index read from input (0 marks the root).
int f;
// Subtree sizes, written by both get_h() and dfs().
int sz[55];
// Smallest "largest remaining component" size seen during the centroid search.
int hz;
// Adjacency lists of the current tree; nodes are numbered from 1.
vector<int> g[55];
// Centroid(s) of the current tree found by get_h().
vector<int> rt;
// Hash of a single-centroid tree -> index of the first tree with that hash.
unordered_map<ULL, int> mp0;
// (smaller, larger) hash pair of a two-centroid tree -> index of the first such tree.
map<pair<ULL, ULL>, int> mp1;
void get_h(int u, int o) {
	sz[u] = 1;
	int mx = 0;
	for (int v : g[u]) {
		if(v != o){
			get_h(v, u);
			sz[u] += sz[v];
			mx = max(mx, sz[v]);	
		}
	}
	mx = max(mx, n-sz[u]);
	if(mx < hz) vector<int>().swap(rt), rt.pb(u), hz = mx;
	else if(mx == hz) rt.pb(u);
}
// Hash of the subtree rooted at u, traversed away from o.
// A leaf hashes to 1. Otherwise the children's hashes are sorted ascending
// (so sibling order does not matter), combined as sum of h[i] * B^i with
// unsigned 64-bit wraparound, and the result is multiplied by the subtree size.
// Side effect: overwrites sz[] with subtree sizes for this rooting.
// NOTE(review): nothing here guarantees the encoding is collision-free.
ULL dfs(int u, int o) {
	vector<ULL> child_hashes;
	sz[u] = 1;
	for (int v : g[u]) {
		if (v == o) continue;
		child_hashes.push_back(dfs(v, u));
		sz[u] += sz[v];
	}
	if (child_hashes.empty()) return 1;
	sort(child_hashes.begin(), child_hashes.end());
	// Horner's rule from the highest-weighted (last) element down to index 0.
	ULL acc = 0;
	for (auto it = child_hashes.rbegin(); it != child_hashes.rend(); ++it) {
		acc = acc * B + *it;
	}
	return acc * sz[u];
}
// Reads m trees, each given as n parent indices (parent 0 marks the root).
// For every tree, prints the index of the first tree (possibly itself) whose
// centroid-rooted hash signature matches, i.e. the first presumed-isomorphic tree.
int main() {
	if (scanf("%d", &m) != 1) return 0;
	for (int cs = 1; cs <= m; ++cs) {
		if (scanf("%d", &n) != 1) return 0;
		// Centroid search is valid from any start node, so default to node 1
		// instead of leaving r uninitialized if no parent equals 0.
		int r = 1;
		for (int i = 1; i <= n; ++i) {
			if (scanf("%d", &f) != 1) return 0;
			if (!f) r = i;
			else g[f].pb(i), g[i].pb(f);
		}
		hz = n;
		get_h(r, r);
		// Every tree has exactly one or two centroids.
		assert(rt.size() <= 2);
		assert(rt.size() >= 1);
		int first_seen;
		if (rt.size() == 2) {
			ULL a = dfs(rt[0], rt[0]);
			ULL b = dfs(rt[1], rt[1]);
			if (a > b) swap(a, b);
			// insert() keeps an existing entry untouched, so the returned iterator
			// points at the earliest tree index stored for this signature.
			first_seen = mp1.insert({{a, b}, cs}).first->second;
		}
		else {
			ULL a = dfs(rt[0], rt[0]);
			first_seen = mp0.insert({a, cs}).first->second;
		}
		printf("%d\n", first_seen);
		// Reset per-tree state for the next test case.
		rt.clear();
		for (int i = 1; i <= n; ++i) g[i].clear();
	}
	return 0;
}