swdata/neighbor_per_round.py
2018-03-11 17:27:59 +08:00

110 lines
2.7 KiB
Python

import json
from matplotlib import pyplot as plt
from island.match import Match
from island.matches import Matches
from numpy import mean, std
import numpy as np
# Match collection built from 'wos-data-new' (project class; how it loads the
# data is defined in island.matches, not here).
matches = Matches('wos-data-new')

# Number of rounds analysed; the loops below use round numbers 1..max_round.
max_round = 17

# survivals.json is later indexed as survivals[match name][round number as
# a string] and tested for membership of player ids.
with open('survivals.json', 'r') as survivals_file:
    survivals = json.loads(survivals_file.read())

# match name -> {player: [neighboring players]}
neighbors = {}

# Per-round mean / standard deviation of the surviving-neighbor count, for
# players whose action was 'C' (c*) and for all other actions (d*).
cmean, cstd = [], []
dmean, dstd = [], []
# Build one adjacency map per match from its 'neighbor'/'create' records.
# Every record links r['a'] with r['b']; the link is stored in both
# directions, so each player ends up with a list of its neighbors.
for match_idx, match in enumerate(matches.data):
    adjacency = {}
    for link in match.query('neighbor', 'create').raw_data:
        first, second = link['a'], link['b']
        # setdefault creates the list on first sight of a player and
        # otherwise returns the existing one, so appends accumulate.
        adjacency.setdefault(first, []).append(second)
        adjacency.setdefault(second, []).append(first)
    neighbors[matches.names[match_idx]] = adjacency
# For every round, count how many neighbors of each acting player are listed
# in survivals for that round, split the counts by the player's action
# ('C' versus anything else), and record mean and standard deviation of
# both groups. An empty group contributes 0 for both statistics.
for round_idx in range(max_round):
    round_no = round_idx + 1
    round_key = str(round_no)  # survivals is keyed by the round number as text
    cneigh = []
    dneigh = []
    for match_idx, match in enumerate(matches.data):
        match_name = matches.names[match_idx]
        rows = match.query('action', 'done').where(
            lambda x: x['rno'] == round_no).raw_data
        # Players already counted for this match in this round; a player
        # appearing in several rows is only counted on first appearance.
        seen = set()
        for row in rows:
            # Each row holds two players ('a' and 'b') and their actions.
            for player_key, action_key in (('a', 'act_a'), ('b', 'act_b')):
                player = row[player_key]
                if player in seen:
                    continue
                # The survivals lookup is deliberately kept inside the
                # generator so it only happens when there is a neighbor to
                # test, exactly as in a plain nested loop.
                alive_neighbors = sum(
                    1
                    for other in neighbors[match_name][player]
                    if other in survivals[match_name][round_key]
                )
                bucket = cneigh if row[action_key] == 'C' else dneigh
                bucket.append(alive_neighbors)
                seen.add(player)
    cmean.append(mean(cneigh) if cneigh else 0)
    cstd.append(std(cneigh) if cneigh else 0)
    dmean.append(mean(dneigh) if dneigh else 0)
    dstd.append(std(dneigh) if dneigh else 0)
# Grouped bar chart: for each round, the mean surviving-neighbor count of
# players who played 'C' (blue) next to that of players who played anything
# else (red, labelled 'D'), with the standard deviation as error bars.
fig, ax = plt.subplots()

# One bar group per round. Derived from max_round instead of a literal 17 so
# the x positions always have the same length as cmean/dmean; ax.bar rejects
# mismatched lengths.
index = np.arange(max_round)
bar_width = 0.35
opacity = 0.4
error_config = {'ecolor': '0.3', 'capsize': 4}

rects1 = ax.bar(index, cmean, bar_width,
                alpha=opacity, color='b',
                yerr=cstd, error_kw=error_config,
                label='C')
# Shift the second series by one bar width so the bars sit side by side.
rects2 = ax.bar(index + bar_width, dmean, bar_width,
                alpha=opacity, color='r',
                yerr=dstd, error_kw=error_config,
                label='D')

# NOTE(review): the labels below look like leftover placeholders and are
# left disabled; supply real axis labels/title if wanted.
# ax.set_xlabel('Group')
# ax.set_ylabel('Scores')
# ax.set_title('Scores by group and gender')

# Centre each tick between the two bars of a group and label it with the
# 1-based round number.
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(index + 1)
ax.legend()
fig.tight_layout()
plt.show()