parent
32612d773c
commit
3d09cd2dbe
@ -1,157 +0,0 @@
|
||||
# encoding: UTF-8
|
||||
# Copyright 2017 Google.com
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from code import my_txtutils as txt
|
||||
|
||||
TST_TXTSIZE = 10000
|
||||
TST_SEQLEN = 10
|
||||
TST_BATCHSIZE = 13
|
||||
TST_EPOCHS = 5
|
||||
|
||||
|
||||
class RnnMinibatchSequencerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# generate text of consecutive items
|
||||
self.data = list(range(TST_TXTSIZE))
|
||||
|
||||
@staticmethod
|
||||
def check_seq_batch(batch1, batch2):
|
||||
nb_errors = 0
|
||||
for i in range(TST_BATCHSIZE):
|
||||
ok = batch1[i, -1] + 1 == batch2[i, 0]
|
||||
nb_errors += 0 if ok else 1
|
||||
return nb_errors
|
||||
|
||||
def test_sequences(self):
|
||||
for x, y, epoch in txt.rnn_minibatch_sequencer(self.data, TST_BATCHSIZE, TST_SEQLEN, TST_EPOCHS):
|
||||
for i in range(TST_BATCHSIZE):
|
||||
self.assertListEqual(x[i, 1:].tolist(), y[i, :-1].tolist(),
|
||||
msg="y sequences must be equal to x sequences shifted by -1")
|
||||
|
||||
def test_batches(self):
|
||||
start = True
|
||||
prev_x = np.zeros([TST_BATCHSIZE, TST_SEQLEN], np.int32)
|
||||
prev_y = np.zeros([TST_BATCHSIZE, TST_SEQLEN], np.int32)
|
||||
nb_errors = 0
|
||||
nb_batches = 0
|
||||
for x, y, epoch in txt.rnn_minibatch_sequencer(self.data, TST_BATCHSIZE, TST_SEQLEN, TST_EPOCHS):
|
||||
if not start:
|
||||
nb_errors += self.check_seq_batch(prev_x, x)
|
||||
nb_errors += self.check_seq_batch(prev_y, y)
|
||||
prev_x = x
|
||||
prev_y = y
|
||||
start = False
|
||||
nb_batches += 1
|
||||
self.assertLessEqual(nb_errors, 2 * TST_EPOCHS,
|
||||
msg="Sequences should be correctly continued, even between epochs. Only "
|
||||
"one sequence is allowed to not continue from one epoch to the next.")
|
||||
self.assertLess(TST_TXTSIZE - (nb_batches * TST_BATCHSIZE * TST_SEQLEN),
|
||||
TST_BATCHSIZE * TST_SEQLEN * TST_EPOCHS,
|
||||
msg="Text ignored at the end of an epoch must be smaller than one batch of sequences")
|
||||
|
||||
|
||||
class EncodingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_text_known_chars = \
|
||||
"PRIDE AND PREJUDICE" \
|
||||
"\n" \
|
||||
"By Jane Austen" \
|
||||
"\n" \
|
||||
"\n" \
|
||||
"\n" \
|
||||
"Chapter 1" \
|
||||
"\n" \
|
||||
"\n" \
|
||||
"It is a truth universally acknowledged, that a single man in possession " \
|
||||
"of a good fortune, must be in want of a wife." \
|
||||
"\n\n" \
|
||||
"However little known the feelings or views of such a man may be on his " \
|
||||
"first entering a neighbourhood, this truth is so well fixed in the minds " \
|
||||
"of the surrounding families, that he is considered the rightful property " \
|
||||
"of some one or other of their daughters." \
|
||||
"\n\n" \
|
||||
"\"My dear Mr. Bennet,\" said his lady to him one day, \"have you heard that " \
|
||||
"Netherfield Park is let at last?\"" \
|
||||
"\n\n" \
|
||||
"Mr. Bennet replied that he had not." \
|
||||
"\n\n" \
|
||||
"\"But it is,\" returned she; \"for Mrs. Long has just been here, and she " \
|
||||
"told me all about it.\"" \
|
||||
"\n\n" \
|
||||
"Mr. Bennet made no answer." \
|
||||
"\n\n" \
|
||||
"\"Do you not want to know who has taken it?\" cried his wife impatiently." \
|
||||
"\n\n" \
|
||||
"\"_You_ want to tell me, and I have no objection to hearing it.\"" \
|
||||
"\n\n" \
|
||||
"This was invitation enough." \
|
||||
"\n\n" \
|
||||
"\"Why, my dear, you must know, Mrs. Long says that Netherfield is taken " \
|
||||
"by a young man of large fortune from the north of England; that he came " \
|
||||
"down on Monday in a chaise and four to see the place, and was so much " \
|
||||
"delighted with it, that he agreed with Mr. Morris immediately; that he " \
|
||||
"is to take possession before Michaelmas, and some of his servants are to " \
|
||||
"be in the house by the end of next week.\"" \
|
||||
"\n\n" \
|
||||
"\"What is his name?\"" \
|
||||
"\n\n" \
|
||||
"\"Bingley.\"" \
|
||||
"\n\n" \
|
||||
"Testing punctuation: !\"#$%&\'()*+,-./0123456789:;<=>?@[\\]^_`{|}~" \
|
||||
"\n" \
|
||||
"Tab\x09Tab\x09Tab\x09Tab" \
|
||||
"\n"
|
||||
self.test_text_unknown_char = "Unknown char: \x0C" # the unknown char 'new page'
|
||||
|
||||
def test_encoding(self):
|
||||
encoded = txt.encode_text(self.test_text_known_chars)
|
||||
decoded = txt.decode_to_text(encoded)
|
||||
self.assertEqual(self.test_text_known_chars, decoded,
|
||||
msg="On a sequence of supported characters, encoding, "
|
||||
"then decoding should yield the original string.")
|
||||
|
||||
def test_unknown_encoding(self):
|
||||
encoded = txt.encode_text(self.test_text_unknown_char)
|
||||
decoded = txt.decode_to_text(encoded)
|
||||
original_fix = self.test_text_unknown_char[:-1] + chr(0)
|
||||
self.assertEqual(original_fix, decoded,
|
||||
msg="The last character of the test sequence is an unsupported "
|
||||
"character and should be encoded and decoded as 0.")
|
||||
|
||||
|
||||
class TxtProgressTest(unittest.TestCase):
|
||||
def test_progress_indicator(self):
|
||||
print("If the printed output of this test is incorrect, the test will fail. No need to check visually.", end='')
|
||||
test_cases = (50, 51, 49, 1, 2, 3, 1000, 333, 101)
|
||||
p = txt.Progress(100)
|
||||
for maxi in test_cases:
|
||||
m, cent = self.check_progress_indicator(p, maxi)
|
||||
self.assertEqual(m, maxi, msg="Incorrect number of steps.")
|
||||
self.assertEqual(cent, 100, msg="Incorrect number of steps.")
|
||||
|
||||
@staticmethod
|
||||
def check_progress_indicator(p, maxi):
|
||||
p._Progress__print_header()
|
||||
progress = p._Progress__start_progress(maxi)
|
||||
total = 0
|
||||
n = 0
|
||||
for k in progress():
|
||||
total += k
|
||||
n += 1
|
||||
return n, total
|
Loading…
Reference in new issue