You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
1.9 KiB
67 lines
1.9 KiB
'''
written by Lorenz Muller
'''
|
|
|
|
import numpy as np
|
|
from time import time
|
|
|
|
|
|
def _read_ratings(path, delimiter):
    '''
    reads the ratings file into a 2-D int32 array (one row per rating)

    :param path: path to the ratings file
    :param delimiter: column delimiter; may be None (whitespace), a single
                      character, or a multi-character string such as '::'
    :return: int32 array of shape (n_ratings, n_columns)
    '''
    if delimiter is None or len(delimiter) == 1:
        raw = np.loadtxt(path, skiprows=0, delimiter=delimiter, ndmin=2)
    else:
        # np.loadtxt only accepts single-character delimiters since
        # NumPy 1.23, so collapse the multi-character delimiter to its
        # first character while streaming the lines.
        single = delimiter[0]
        with open(path) as handle:
            lines = (line.replace(delimiter, single) for line in handle)
            raw = np.loadtxt(lines, skiprows=0, delimiter=single, ndmin=2)
    # every column is cast to int32, so fractional values are truncated
    return raw.astype('int32')


def loadData(path='./', valfrac=0.1, delimiter='::', seed=1234,
             transpose=False):
    '''
    loads ml-1m data

    The file is expected to hold one rating per line with the columns
    user id, movie id, rating (any further columns are ignored). User and
    movie ids are mapped to contiguous row/column numbers (from zero) in
    sorted id order. Entries without a rating are 0 in both matrices.

    The ratings are shuffled with the global numpy RNG (which is re-seeded
    with `seed`); the shuffled positions i with i < valfrac * n_ratings
    form the validation set, all others the training set.

    :param path: path to the ratings file
    :param valfrac: fraction of data to use for validation
    :param delimiter: delimiter used in data file
    :param seed: random seed for validation splitting
    :param transpose: flag to transpose output matrices (swapping users with movies)
    :return: train ratings (n_u, n_m), valid ratings (n_u, n_m), both float32
    '''
    np.random.seed(seed)

    tic = time()
    print('reading data...')
    data = _read_ratings(path, delimiter)
    print('data read in', time() - tic, 'seconds')

    # return_inverse yields, for every rating, the contiguous user/movie
    # number (position of its id among the sorted unique ids)
    users, u_idx = np.unique(data[:, 0], return_inverse=True)
    movies, m_idx = np.unique(data[:, 1], return_inverse=True)
    n_u = users.shape[0]  # number of users
    n_m = movies.shape[0]  # number of movies
    n_r = data.shape[0]  # number of ratings

    # shuffle indices
    idx = np.arange(n_r)
    np.random.shuffle(idx)

    # the first few ratings of the shuffled data array are validation data,
    # the rest are training data; the comparison is strict so that
    # valfrac=0 holds out nothing
    in_valid = np.arange(n_r) < valfrac * n_r
    valid_rows = idx[in_valid]
    train_rows = idx[~in_valid]

    trainRatings = np.zeros((n_u, n_m), dtype='float32')
    validRatings = np.zeros((n_u, n_m), dtype='float32')

    # NOTE(review): assumes each (user, movie) pair occurs at most once in
    # the file; with duplicates the surviving rating is unspecified
    validRatings[u_idx[valid_rows], m_idx[valid_rows]] = data[valid_rows, 2]
    trainRatings[u_idx[train_rows], m_idx[train_rows]] = data[train_rows, 2]

    if transpose:
        trainRatings = trainRatings.T
        validRatings = validRatings.T

    print('loaded dense data matrix')

    return trainRatings, validRatings
|
|
|