[BZOJ 2157] 旅游
切完题来水一发……
拥有良心样例数据的树链剖分模板题……
窝已经要突破天际了，哈哈哈！写这货都能有 250+ 行，哈哈哈！已经开始脑补 NOI 2005 维修数列了，哈哈哈哈哈哈……（←发抽 ing）
/**************************************************************
    Problem: 2157
    User: cyand1317
    Language: C++
    Result: Accepted
    Time:632 ms
    Memory:3472 kb
****************************************************************/
// BZOJ 2157 (Tourism).
//
// Each tree edge's weight is stored at the position of its child endpoint.
// Heavy-light decomposition maps any u-v path onto O(log n) contiguous ranges
// of a segment tree supporting point assignment, range negation and range
// sum/min/max.
//
// Input read by main(): n, then n-1 edges "u v w" (numbered 1..n-1, vertices
// 0..n-1), then m operations "OP a b":
//   C i w         set the weight of edge i to w
//   N u v         negate every edge weight on the u-v path
//   SUM/MIN/MAX   print the sum / minimum / maximum over the u-v path
// NOTE: for u == v the path is empty; MIN/MAX then print INF / -INF exactly as
// the previous version did.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

static const int MAXN = 20006;
static const int INF = 0x7fffffff;

// ---------------------------------------------------------------------------
// Aggregate of a set of edge weights.
// ---------------------------------------------------------------------------
struct Agg {
    int sum, min, max;
};

// Identity element for agg_merge (aggregate of an empty range).
static const Agg AGG_EMPTY = {0, INF, -INF};

static Agg agg_merge(const Agg &a, const Agg &b) {
    Agg c = {a.sum + b.sum, std::min(a.min, b.min), std::max(a.max, b.max)};
    return c;
}

// ---------------------------------------------------------------------------
// Segment tree over HLD positions 0..n-1 (root at index 0).
// Invariant: a node's stored sum/min/max do NOT yet include its own pending
// negation (lz_rev); sgt_push applies it and forwards the tag to the children.
// ---------------------------------------------------------------------------
struct SgtNode {
    int l, r;      // covered position range [l, r]
    int lch, rch;  // child indices, -1 for a leaf
    int sum, min, max;
    bool lz_rev;   // pending negation of this whole subtree
};
static SgtNode sgt[MAXN * 2];  // 2n-1 nodes suffice for n leaves
static int sgt_cnt = 0;        // next free node index used by sgt_build
int sgt_seq[MAXN];             // initial leaf values, indexed by position

// Apply idx's pending negation to its own aggregates and pass it down.
void sgt_push(int idx) {
    SgtNode &nd = sgt[idx];
    if (!nd.lz_rev) return;
    nd.lz_rev = false;
    nd.sum = -nd.sum;
    // Negation swaps the roles of the extremes: min' = -max, max' = -min.
    const int old_min = nd.min;
    nd.min = -nd.max;
    nd.max = -old_min;
    if (nd.lch != -1) {
        sgt[nd.lch].lz_rev = !sgt[nd.lch].lz_rev;
        sgt[nd.rch].lz_rev = !sgt[nd.rch].lz_rev;
    }
}

// Recompute an internal node's aggregates from its children, first resolving
// any negation still pending on them. No-op for leaves.
void sgt_update(int idx) {
    SgtNode &nd = sgt[idx];
    if (nd.lch == -1) return;
    sgt_push(nd.lch);
    sgt_push(nd.rch);
    const SgtNode &a = sgt[nd.lch];
    const SgtNode &b = sgt[nd.rch];
    nd.sum = a.sum + b.sum;
    nd.min = std::min(a.min, b.min);
    nd.max = std::max(a.max, b.max);
}

// Aggregate of positions [l, r] within the subtree at idx. Empty or disjoint
// ranges yield AGG_EMPTY instead of descending into a leaf's -1 children.
static Agg sgt_query(int l, int r, int idx = 0) {
    const SgtNode &nd = sgt[idx];
    if (l > r || nd.r < l || nd.l > r) return AGG_EMPTY;
    sgt_push(idx);
    if (l <= nd.l && nd.r <= r) {
        Agg here = {nd.sum, nd.min, nd.max};
        return here;
    }
    // Not contained and not disjoint, so this is an internal node.
    return agg_merge(sgt_query(l, r, nd.lch), sgt_query(l, r, nd.rch));
}

// Assign val to position pos (pos must lie within the tree's range).
void sgt_pointset(int pos, int val, int idx = 0) {
    sgt_push(idx);
    SgtNode &nd = sgt[idx];
    if (nd.lch == -1) {  // leaf: the descent below only reaches leaf pos
        nd.sum = nd.min = nd.max = val;
        return;
    }
    if (pos <= sgt[nd.lch].r)
        sgt_pointset(pos, val, nd.lch);
    else
        sgt_pointset(pos, val, nd.rch);
    sgt_update(idx);
}

// Negate every value at positions [l, r].
void sgt_intvrev(int l, int r, int idx = 0) {
    SgtNode &nd = sgt[idx];
    if (l > r || nd.r < l || nd.l > r) return;
    sgt_push(idx);
    if (l <= nd.l && nd.r <= r) {
        // The push above cleared the tag; the new one is resolved lazily by
        // the parent's sgt_update or by the next visit to this node.
        nd.lz_rev = true;
        return;
    }
    sgt_intvrev(l, r, nd.lch);
    sgt_intvrev(l, r, nd.rch);
    sgt_update(idx);
}

// Build the subtree for positions [l, r] from sgt_seq; returns its index.
// The first call allocates index 0, which every other sgt_* uses as root.
int sgt_build(int l, int r) {
    const int idx = sgt_cnt++;
    SgtNode &nd = sgt[idx];
    nd.l = l;
    nd.r = r;
    nd.lz_rev = false;
    if (l == r) {
        nd.lch = nd.rch = -1;
        nd.sum = nd.min = nd.max = sgt_seq[l];
    } else {
        const int m = (l + r) / 2;
        nd.lch = sgt_build(l, m);
        nd.rch = sgt_build(m + 1, r);
        sgt_update(idx);
    }
    return idx;
}

// ---------------------------------------------------------------------------
// Tree and heavy-light decomposition.
// ---------------------------------------------------------------------------
struct Edge {
    int dest, len;
    Edge(int d, int w) : dest(d), len(w) {}
};
typedef std::vector<Edge> edgelist;

int n;
edgelist e[MAXN];               // adjacency lists
std::pair<int, int> e_list[MAXN];  // endpoints of input edge i (1-based)
int par[MAXN], dep[MAXN], sts[MAXN], pfc[MAXN];  // parent, depth, subtree size, heavy child
int tvn[MAXN], hct[MAXN];       // HLD position, head of the node's heavy chain
int par_dist[MAXN];             // weight of the edge to the parent (root: 0)
static int bfs_order[MAXN];     // scratch for hld_dfs1
static int dfs_stack[MAXN];     // scratch for hld_dfs2

// Root the tree at r (parent p, depth d) and fill par/dep/par_dist/sts/pfc.
// Iterative so that path-shaped trees cannot overflow the call stack.
void hld_dfs1(int r = 0, int p = -1, int d = 0) {
    int head = 0, tail = 0;
    par[r] = p;
    dep[r] = d;
    par_dist[r] = 0;
    bfs_order[tail++] = r;
    // Top-down: parents are discovered before their children.
    while (head < tail) {
        const int u = bfs_order[head++];
        const edgelist &adj = e[u];
        for (edgelist::const_iterator it = adj.begin(); it != adj.end(); ++it) {
            if (it->dest == par[u]) continue;
            par[it->dest] = u;
            dep[it->dest] = dep[u] + 1;
            par_dist[it->dest] = it->len;
            bfs_order[tail++] = it->dest;
        }
    }
    // Bottom-up (reverse BFS order): children are finished before parents.
    for (int i = tail - 1; i >= 0; --i) {
        const int u = bfs_order[i];
        const edgelist &adj = e[u];
        sts[u] = 1;
        pfc[u] = -1;
        int max_sts = -1;
        for (edgelist::const_iterator it = adj.begin(); it != adj.end(); ++it) {
            if (it->dest == par[u]) continue;
            sts[u] += sts[it->dest];
            // Strict '>' keeps the first largest child in adjacency order.
            if (max_sts < sts[it->dest]) {
                max_sts = sts[it->dest];
                pfc[u] = it->dest;
            }
        }
    }
}

// Assign HLD positions (tvn) and chain heads (hct), copying each node's
// parent-edge weight into sgt_seq. Heavy children are visited right after
// their parent so every heavy chain occupies a contiguous position range.
void hld_dfs2(int r = 0, int top = 0) {
    int pos = 0, sp = 0;
    hct[r] = top;
    dfs_stack[sp++] = r;
    while (sp > 0) {
        const int u = dfs_stack[--sp];
        tvn[u] = pos;
        sgt_seq[pos] = par_dist[u];
        ++pos;
        // Push light children in reverse so they pop in adjacency order,
        // then the heavy child last so it pops next.
        const edgelist &adj = e[u];
        for (edgelist::const_reverse_iterator it = adj.rbegin(); it != adj.rend(); ++it) {
            const int c = it->dest;
            if (c == par[u] || c == pfc[u]) continue;
            hct[c] = c;  // a light child starts its own chain
            dfs_stack[sp++] = c;
        }
        if (pfc[u] != -1) {
            hct[pfc[u]] = hct[u];
            dfs_stack[sp++] = pfc[u];
        }
    }
}

// Aggregate of the edge weights on the u-v path.
static Agg hld_query(int u, int v) {
    Agg acc = AGG_EMPTY;
    while (hct[u] != hct[v]) {
        // Lift the endpoint whose chain head is deeper; its whole chain
        // segment (head included, via the head's parent edge) is on the path.
        if (dep[hct[u]] < dep[hct[v]]) std::swap(u, v);
        const int h = hct[u];
        acc = agg_merge(acc, sgt_query(tvn[h], tvn[u]));
        u = par[h];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    // u is now the LCA; its own position holds the edge above it, so skip it.
    if (u != v) acc = agg_merge(acc, sgt_query(tvn[u] + 1, tvn[v]));
    return acc;
}

int hld_querysum(int u, int v) { return hld_query(u, v).sum; }
int hld_querymin(int u, int v) { return hld_query(u, v).min; }
int hld_querymax(int u, int v) { return hld_query(u, v).max; }

// Negate every edge weight on the u-v path (same traversal as hld_query).
void hld_pathrev(int u, int v) {
    while (hct[u] != hct[v]) {
        if (dep[hct[u]] < dep[hct[v]]) std::swap(u, v);
        const int h = hct[u];
        sgt_intvrev(tvn[h], tvn[u]);
        u = par[h];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    if (u != v) sgt_intvrev(tvn[u] + 1, tvn[v]);
}

// Set the weight of tree edge (u, v); it is stored at the child endpoint.
inline void hld_pointset(int u, int v, int val) {
    const int child = (par[u] == v) ? u : v;
    sgt_pointset(tvn[child], val);
}

int main() {
    if (scanf("%d", &n) != 1 || n < 1 || n > MAXN) return 1;
    for (int i = 1; i < n; ++i) {
        int u, v, w;
        if (scanf("%d%d%d", &u, &v, &w) != 3) return 1;
        e[u].push_back(Edge(v, w));
        e[v].push_back(Edge(u, w));
        e_list[i] = std::make_pair(u, v);
    }
    hld_dfs1();
    hld_dfs2();
    sgt_build(0, n - 1);

    int m = 0;
    if (scanf("%d", &m) != 1) return 1;
    char op[16];
    int a1, a2;
    // 'while' (not do-while) so m == 0 performs no operations; %15s keeps
    // the read inside op[16].
    while (m-- > 0 && scanf("%15s%d%d", op, &a1, &a2) == 3) {
        if (op[0] == 'C')
            hld_pointset(e_list[a1].first, e_list[a1].second, a2);
        else if (op[0] == 'N')
            hld_pathrev(a1, a2);
        else if (op[0] == 'S')       // SUM
            printf("%d\n", hld_querysum(a1, a2));
        else if (op[1] == 'I')       // MIN
            printf("%d\n", hld_querymin(a1, a2));
        else                         // MAX
            printf("%d\n", hld_querymax(a1, a2));
    }
    return 0;
}