Fix bug in the chunks-per-user function caused by the numpy version of array_split

This commit is contained in:
René Knaebel 2017-07-16 18:49:14 +02:00
parent 844494eca9
commit d33c9f44ec
2 changed files with 11 additions and 5 deletions

View File

@ -71,9 +71,13 @@ def get_user_chunks(user_flow, window=10):
# domains.pop(-1)
# flows.pop(-1)
# return domains, flows
result = []
chunk_size = (len(user_flow) // window)
last_inrange = chunk_size * window
return np.split(user_flow.head(last_inrange), chunk_size) if chunk_size else []
for i in range(chunk_size):
result.append(user_flow.iloc[i * window:(i + 1) * window])
if result and len(result[-1]) != window:
result.pop()
return result
def get_domain_features(domain, vocab: dict, max_length=40):
@ -153,7 +157,9 @@ def create_dataset_from_lists(chunks, vocab, max_len):
logger.info(" compute domain features")
domain_features = []
for ds in tqdm(map(lambda f: f.domain, chunks)):
assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}"
# TODO: fix this correctly
# assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}"
if not ds: continue
domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
domain_features = np.concatenate(domain_features, 0)
logger.info(" compute flow features")
@ -161,7 +167,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
logger.info(" select hits")
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, chunks)), axis=1)
logger.info(" select names")
names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)), axis=1)
names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)))
logger.info(" select servers")
servers = np.max(np.stack(map(lambda f: f.serverLabel, chunks)), axis=1)
logger.info(" select trusted hits")

View File

@ -68,7 +68,7 @@ class Hyperband:
r = self.max_iter * self.eta ** (-s)
# n random configurations
T = [self.get_params() for i in range(n)]
T = [self.get_params() for _ in range(n)]
for i in range((s + 1) - int(skip_last)): # changed from s + 1