Skip to content

Commit

Permalink
update API docs on tfutils
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Oct 28, 2017
1 parent 403815b commit 6f55416
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 29 deletions.
7 changes: 4 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,18 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'replace_get_variable',
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
'predictor_factory',
'get_predictors',
'RandomCropAroundBox',
'GaussianDeform',
'dump_chkpt_vars',
'VisualQA',
'huber_loss',
'DumpTensor',
'StagingInputWrapper',
'StepTensorPrinter'
'StepTensorPrinter',

'guided_relu', 'saliency_map', 'get_scalar_var',
'prediction_incorrect', 'huber_loss',
]:
return True
if name in ['get_data', 'size', 'reset_state']:
Expand Down
30 changes: 23 additions & 7 deletions docs/modules/tfutils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ tensorpack.tfutils.gradproc module
:undoc-members:
:show-inheritance:

tensorpack.tfutils.tower module
------------------------------------

.. automodule:: tensorpack.tfutils.tower
:members:
:undoc-members:
:show-inheritance:

tensorpack.tfutils.scope_utils module
--------------------------------------

Expand All @@ -47,6 +55,14 @@ tensorpack.tfutils.sesscreate module
:undoc-members:
:show-inheritance:

tensorpack.tfutils.sessinit module
------------------------------------

.. automodule:: tensorpack.tfutils.sessinit
:members:
:undoc-members:
:show-inheritance:

tensorpack.tfutils.summary module
---------------------------------

Expand Down Expand Up @@ -79,11 +95,11 @@ tensorpack.tfutils.varreplace module
:undoc-members:
:show-inheritance:

Module contents
---------------

.. automodule:: tensorpack.tfutils
:members:
:undoc-members:
:show-inheritance:
Other functions in tensorpack.tfutils module
---------------------------------------------

.. automethod:: tensorpack.tfutils.get_default_sess_config
.. automethod:: tensorpack.tfutils.get_global_step_var
.. automethod:: tensorpack.tfutils.get_global_step_value
.. automethod:: tensorpack.tfutils.argscope
.. automethod:: tensorpack.tfutils.get_arg_scope
6 changes: 1 addition & 5 deletions tensorpack/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name

__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']


@six.add_metaclass(ABCMeta)
Expand Down Expand Up @@ -206,10 +206,6 @@ def __str__(self):
return type(self).__name__


# back-compat. in case someone write something in triggerable
Triggerable = Callback


class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
Expand Down
5 changes: 4 additions & 1 deletion tensorpack/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def _global_import(name):


_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utils', 'registry']
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if not module_name.startswith('_'):
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
_global_import(module_name)
2 changes: 2 additions & 0 deletions tensorpack/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@

from .registry import layer_register # noqa
from .utils import VariableHolder, rename_get_variable # noqa

__all__ = ['layer_register', 'VariableHolder']
7 changes: 4 additions & 3 deletions tensorpack/tfutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from pkgutil import iter_modules
import os

__all__ = []
from .tower import get_current_tower_context, TowerContext
# don't want to include everything from .tower
__all__ = ['get_current_tower_context', 'TowerContext']


def _global_import(name):
Expand All @@ -21,7 +23,6 @@ def _global_import(name):
'common',
'sessinit',
'argscope',
'tower',
])

_CURR_DIR = os.path.dirname(__file__)
Expand All @@ -36,4 +37,4 @@ def _global_import(name):
_global_import(module_name) # import the content to tfutils.*
__all__.extend(['sessinit', 'summary', 'optimizer',
'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
'distributed'])
'distributed', 'tower'])
9 changes: 0 additions & 9 deletions tensorpack/tfutils/symbolic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,6 @@

# this function exists for backwards-compatibilty
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
Args:
logits: shape [B,C].
label: shape [B].
topk(int): topk
Returns:
a float32 vector of length N with 0/1 values. 1 means incorrect
prediction.
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name)

Expand Down
2 changes: 1 addition & 1 deletion tensorpack/tfutils/varreplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..utils.develop import deprecated

__all__ = ['custom_getter_scope', 'replace_get_variable',
__all__ = ['replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables']

Expand Down

0 comments on commit 6f55416

Please sign in to comment.