Links

考虑两种限制:

对于第一种限制,我们设 F(u)F(u)uu 点在 SSTT 的最短路径上出现的次数,那第一个限制就是要满足 F(a)+F(b)=F(t)F(a) + F(b) = F(t)

这个 FF 可以用两次最短路求出。

对于第二种限制,我们可以考虑枚举每个 uu 算答案时来满足,即,你枚举一个 uu ,那 vvFF 值你可以算出来,那你用一个桶装 bitset 即可找出满足第一个限制的 vv ,对于第二个限制,你再分别做两遍拓扑求出与每个点 uu 在同一条最短路上的点用 bitset 存下来,与一下即可解决。

感觉思维难度并不大,关键是要对限制的合理转化,思维瓶颈在把第一个限制换个形式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <bits/stdc++.h>
using namespace std;
template <typename T> inline void read(T &a){
T w = 1; a = 0;
char ch = getchar();
for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') w = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) a = (a * 10) + (ch - '0');
a *= w;
}
#define int long long
template <typename T> inline void ckmax(T &a, T b){a = a > b ? a : b;}
template <typename T> inline void ckmin(T &a, T b){a = a < b ? a : b;}
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define mii map<int, int>
#define pii pair<int, int>
#define vi vector<int>
#define si set<int>
#define ins insert
#define LL long long
#define era erase
#define Debug(x) cout << #x << " = " << x << endl
#define For(i,l,r) for (int i = l; i <= r; ++i)
#define foR(i,l,r) for (int i = l; i >= r; --i)
const int N = 5e4 + 10;
struct edge {
int to, nxt;
LL dis;
} e[N << 1];
int head[N], cnt;
inline void add (int u, int v, LL w) {
e[++cnt].to = v; e[cnt].nxt = head[u]; head[u] = cnt; e[cnt].dis = w;
}
int n, m, S, T;
struct Node {
int id; LL dis;
Node (int Id = 0, LL Dis = 0) {
id = Id, dis = Dis;
}
bool operator < (const Node &a) const {
return dis > a.dis;
}
};

LL dis[2][N], f[2][N];
bool vis[N];
inline void Dij (int op) {
priority_queue <Node> q;
memset (dis[op], 0x3f, sizeof dis[op]);
memset (vis, 0, sizeof vis);
if (op == 0) q.push(Node(S, 0)), dis[op][S] = 0, f[op][S] = 1;
else q.push(Node(T, 0)), dis[op][T] = 0, f[op][T] = 1;
while (!q.empty()) {
Node tmp = q.top(); q.pop();
if (vis[tmp.id]) continue;
vis[tmp.id] = 1;
int u = tmp.id;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (dis[op][v] == dis[op][u] + e[i].dis) f[op][v] += f[op][u];
if (dis[op][v] > dis[op][u] + e[i].dis) {
dis[op][v] = dis[op][u] + e[i].dis;
f[op][v] = f[op][u];
q.push(Node(v, dis[op][v]));
}
}
}
}

int shortest, F[N];
map<LL, bitset <N>> t;
bitset <N> g[2][N];
bool InShort (int u) { return (dis[0][u] + dis[1][u] == shortest); }
int du[N];

void Topu (int op) {
static queue <int> q; while (!q.empty()) q.pop();

For (u, 1, n)
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (dis[op][u] + e[i].dis + dis[op ^ 1][v] == shortest) du[v]++;
}
For (i, 1, n) {
g[op][i].set(), g[op][i][0] = g[op][i][i] = 0;
if (!du[i]) q.push(i);
}
// For (i, 1, n) cout << g[op][i] << endl;
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (dis[op][u] + e[i].dis + dis[op ^ 1][v] == shortest) {
du[v]--;
g[op][v] &= g[op][u];
if (!du[v]) q.push(v);
}
}
}

}

signed main() {
read(n), read(m), read(S), read(T);
For (i, 1, m) {
int u, v; LL w; read(u), read(v), read(w);
add(u, v, w), add(v, u, w);
}
Dij(0), Dij(1);
shortest = dis[0][T];
if (!f[0][T]) return printf ("%lld\n", 1ll * n * (n - 1) / 2), 0;
For (i, 1, n) {if (InShort(i)) {
F[i] = f[0][i] * f[1][i];
}
t[F[i]].set(i);
}
Topu(0), Topu(1);
LL Ans = 0;
For (i, 1, n) Ans += (t[F[T] - F[i]] & g[0][i] & g[1][i]).count();
printf ("%lld\n", Ans >> 1);
}