The following issues were found:
torch/utils/show_pickle.py
22 issues
Line: 33
Column: 13
return
if obj.state is None:
stream.write(f"{obj.module}.{obj.name}")
printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
return
if not obj.args:
stream.write(f"{obj.module}.{obj.name}()(state=\n")
indent += printer._indent_per_level
stream.write(" " * indent)
Reported by Pylint.
Line: 37
Column: 23
return
if not obj.args:
stream.write(f"{obj.module}.{obj.name}()(state=\n")
indent += printer._indent_per_level
stream.write(" " * indent)
printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
stream.write(")")
return
raise Exception("Need to implement")
Reported by Pylint.
Line: 39
Column: 13
stream.write(f"{obj.module}.{obj.name}()(state=\n")
indent += printer._indent_per_level
stream.write(" " * indent)
printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
stream.write(")")
return
raise Exception("Need to implement")
Reported by Pylint.
Line: 61
Column: 21
return FakeObject(self.module, self.name, args[1:])
class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined]
def __init__(
self,
file,
*,
catch_invalid_utf8=False,
Reported by Pylint.
Line: 77
Column: 21
def persistent_load(self, pid):
return FakeObject("pers", "obj", (pid,))
dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined]
# Custom objects in TorchScript are able to return invalid UTF-8 strings
# from their pickle (__getstate__) functions. Install a custom loader
# for strings that catches the decode exception and replaces it with
# a sentinel object.
Reported by Pylint.
Line: 146
Column: 5
# This hack works on every version of Python I've tested.
# I've tested on the following versions:
# 3.7.4
if True:
pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore[attr-defined]
sys.exit(main(sys.argv))
Reported by Pylint.
Line: 147
Column: 9
# I've tested on the following versions:
# 3.7.4
if True:
pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore[attr-defined]
sys.exit(main(sys.argv))
Reported by Pylint.
Line: 1
Column: 1
#!/usr/bin/env python3
import sys
import pickle
import struct
import pprint
import zipfile
import fnmatch
from typing import Any, IO, BinaryIO, Union
Reported by Pylint.
Line: 3
Suggestion:
https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b403-import-pickle
#!/usr/bin/env python3
import sys
import pickle
import struct
import pprint
import zipfile
import fnmatch
from typing import Any, IO, BinaryIO, Union
Reported by Bandit.
Line: 11
Column: 1
from typing import Any, IO, BinaryIO, Union
class FakeObject(object):
def __init__(self, module, name, args):
self.module = module
self.name = name
self.args = args
# NOTE: We don't distinguish between state never set and state set to None.
Reported by Pylint.
torch/fx/experimental/unification/more.py
22 issues
Line: 1
Column: 1
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
Reported by Pylint.
Line: 2
Column: 1
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
Reported by Pylint.
Line: 1
Column: 1
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
Reported by Pylint.
Line: 33
Column: 1
#########
def reify_object(o, s):
""" Reify a Python object with a substitution
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
Reported by Pylint.
Line: 33
Column: 1
#########
def reify_object(o, s):
""" Reify a Python object with a substitution
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
Reported by Pylint.
Line: 48
Column: 5
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
Reported by Pylint.
Line: 54
Column: 1
return _reify_object_dict(o, s)
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
Reported by Pylint.
Line: 54
Column: 1
return _reify_object_dict(o, s)
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
Reported by Pylint.
Line: 56
Column: 5
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
return obj
Reported by Pylint.
Line: 63
Column: 1
return obj
def _reify_object_slots(o, s):
attrs = [getattr(o, attr) for attr in o.__slots__]
new_attrs = reify(attrs, s)
if attrs == new_attrs:
return o
else:
Reported by Pylint.
torch/utils/data/datapipes/iter/selecting.py
22 issues
Line: 4
Column: 1
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
from .callable import MapIterDataPipe
T_co = TypeVar('T_co', covariant=True)
@functional_datapipe('filter')
Reported by Pylint.
Line: 38
Column: 9
super().__init__(datapipe, fn=filter_fn, fn_args=fn_args, fn_kwargs=fn_kwargs, nesting_level=nesting_level)
def __iter__(self) -> Iterator[T_co]:
res: bool
for data in self.datapipe:
filtered = self._applyFilter(data, self.nesting_level)
if self._isNonEmpty(filtered):
yield filtered
Reported by Pylint.
Line: 1
Column: 1
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
from .callable import MapIterDataPipe
T_co = TypeVar('T_co', covariant=True)
@functional_datapipe('filter')
Reported by Pylint.
Line: 2
Column: 1
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
from .callable import MapIterDataPipe
T_co = TypeVar('T_co', covariant=True)
@functional_datapipe('filter')
Reported by Pylint.
Line: 6
Column: 1
from .callable import MapIterDataPipe
T_co = TypeVar('T_co', covariant=True)
@functional_datapipe('filter')
class FilterIterDataPipe(MapIterDataPipe):
r""" :class:`FilterIterDataPipe`.
Reported by Pylint.
Line: 19
Column: 1
filter_fn: Customized function mapping an element to a boolean.
fn_args: Positional arguments for `filter_fn`
fn_kwargs: Keyword arguments for `filter_fn`
drop_empty_batches: By default, drops batch if it is empty after filtering instead of keeping an empty list
nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
This also accepts -1 as input to apply filtering to the lowest nesting level. It currently doesn't support
argument < -1.
"""
drop_empty_batches: bool
Reported by Pylint.
Line: 20
Column: 1
fn_args: Positional arguments for `filter_fn`
fn_kwargs: Keyword arguments for `filter_fn`
drop_empty_batches: By default, drops batch if it is empty after filtering instead of keeping an empty list
nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
This also accepts -1 as input to apply filtering to the lowest nesting level. It currently doesn't support
argument < -1.
"""
drop_empty_batches: bool
Reported by Pylint.
Line: 21
Column: 1
fn_kwargs: Keyword arguments for `filter_fn`
drop_empty_batches: By default, drops batch if it is empty after filtering instead of keeping an empty list
nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
This also accepts -1 as input to apply filtering to the lowest nesting level. It currently doesn't support
argument < -1.
"""
drop_empty_batches: bool
def __init__(self,
Reported by Pylint.
Line: 26
Column: 5
"""
drop_empty_batches: bool
def __init__(self,
datapipe: IterDataPipe[T_co],
filter_fn: Callable[..., bool],
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
drop_empty_batches: bool = True,
Reported by Pylint.
Line: 35
Column: 1
nesting_level: int = 0,
) -> None:
self.drop_empty_batches = drop_empty_batches
super().__init__(datapipe, fn=filter_fn, fn_args=fn_args, fn_kwargs=fn_kwargs, nesting_level=nesting_level)
def __iter__(self) -> Iterator[T_co]:
res: bool
for data in self.datapipe:
filtered = self._applyFilter(data, self.nesting_level)
Reported by Pylint.
torch/utils/benchmark/utils/timer.py
22 issues
Line: 16
Column: 4
__all__ = ["Timer", "timer", "Language"]
if torch.has_cuda and torch.cuda.is_available():
def timer() -> float:
torch.cuda.synchronize()
return timeit.default_timer()
else:
timer = timeit.default_timer
Reported by Pylint.
Line: 303
Column: 24
with common.set_torch_threads(self._task_spec.num_threads):
# Estimate the block size needed for measurement to be negligible
# compared to the inner loop. This also serves as a warmup.
overhead = torch.tensor([self._timer.timeit(0) for _ in range(5)]).median().item()
number = 1
while True:
time_taken = self._timer.timeit(number)
relative_overhead = overhead / time_taken
if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
Reported by Pylint.
Line: 35
Column: 9
stmt: str,
setup: str,
global_setup: str,
timer: Callable[[], float],
globals: Dict[str, Any],
) -> None:
if timer is not timeit.default_timer:
raise NotImplementedError(
"PyTorch was built with CUDA and a GPU is present; however "
Reported by Pylint.
Line: 36
Column: 9
setup: str,
global_setup: str,
timer: Callable[[], float],
globals: Dict[str, Any],
) -> None:
if timer is not timeit.default_timer:
raise NotImplementedError(
"PyTorch was built with CUDA and a GPU is present; however "
"Timer does not yet support GPU measurements. If your "
Reported by Pylint.
Line: 182
Column: 9
stmt: str = "pass",
setup: str = "pass",
global_setup: str = "",
timer: Callable[[], float] = timer,
globals: Optional[Dict[str, Any]] = None,
label: Optional[str] = None,
sub_label: Optional[str] = None,
description: Optional[str] = None,
env: Optional[str] = None,
Reported by Pylint.
Line: 183
Column: 9
setup: str = "pass",
global_setup: str = "",
timer: Callable[[], float] = timer,
globals: Optional[Dict[str, Any]] = None,
label: Optional[str] = None,
sub_label: Optional[str] = None,
description: Optional[str] = None,
env: Optional[str] = None,
num_threads: int = 1,
Reported by Pylint.
Line: 394
Column: 23
def time_hook() -> float:
return self._timer.timeit(number)
def stop_hook(times: List[float]) -> bool:
return True
times = self._threaded_measurement_loop(
number, time_hook, stop_hook,
min_run_time=min_run_time,
Reported by Pylint.
Line: 17
Column: 5
if torch.has_cuda and torch.cuda.is_available():
def timer() -> float:
torch.cuda.synchronize()
return timeit.default_timer()
else:
timer = timeit.default_timer
Reported by Pylint.
Line: 24
Column: 1
timer = timeit.default_timer
class Language(enum.Enum):
PYTHON = 0
CPP = 1
class CPPTimer:
Reported by Pylint.
Line: 29
Column: 1
CPP = 1
class CPPTimer:
def __init__(
self,
stmt: str,
setup: str,
global_setup: str,
Reported by Pylint.
test/package/test_glob_group.py
22 issues
Line: 3
Column: 1
from typing import Iterable
from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
Reported by Pylint.
Line: 4
Column: 1
from typing import Iterable
from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
Reported by Pylint.
Line: 1
Column: 1
from typing import Iterable
from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
Reported by Pylint.
Line: 13
Column: 1
from common import PackageTestCase
class TestGlobGroup(PackageTestCase):
def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertTrue(glob.matches(candidate))
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
Reported by Pylint.
Line: 14
Column: 5
class TestGlobGroup(PackageTestCase):
def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertTrue(glob.matches(candidate))
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
Reported by Pylint.
Line: 14
Column: 5
class TestGlobGroup(PackageTestCase):
def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertTrue(glob.matches(candidate))
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
Reported by Pylint.
Line: 18
Column: 5
for candidate in candidates:
self.assertTrue(glob.matches(candidate))
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertFalse(glob.matches(candidate))
def test_one_star(self):
glob_group = GlobGroup("torch.*")
Reported by Pylint.
Line: 18
Column: 5
for candidate in candidates:
self.assertTrue(glob.matches(candidate))
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertFalse(glob.matches(candidate))
def test_one_star(self):
glob_group = GlobGroup("torch.*")
Reported by Pylint.
Line: 22
Column: 5
for candidate in candidates:
self.assertFalse(glob.matches(candidate))
def test_one_star(self):
glob_group = GlobGroup("torch.*")
self.assertMatchesGlob(glob_group, ["torch.foo", "torch.bar"])
self.assertNotMatchesGlob(glob_group, ["tor.foo", "torch.foo.bar", "torch"])
def test_one_star_middle(self):
Reported by Pylint.
Line: 27
Column: 5
self.assertMatchesGlob(glob_group, ["torch.foo", "torch.bar"])
self.assertNotMatchesGlob(glob_group, ["tor.foo", "torch.foo.bar", "torch"])
def test_one_star_middle(self):
glob_group = GlobGroup("foo.*.bar")
self.assertMatchesGlob(glob_group, ["foo.q.bar", "foo.foo.bar"])
self.assertNotMatchesGlob(
glob_group,
[
Reported by Pylint.
torch/backends/mkldnn/__init__.py
22 issues
Line: 8
Column: 12
def is_available():
r"""Returns whether PyTorch is built with MKL-DNN support."""
return torch._C.has_mkldnn
def set_flags(_enabled):
orig_flags = (torch._C._get_mkldnn_enabled(),)
torch._C._set_mkldnn_enabled(_enabled)
return orig_flags
Reported by Pylint.
Line: 11
Column: 19
return torch._C.has_mkldnn
def set_flags(_enabled):
orig_flags = (torch._C._get_mkldnn_enabled(),)
torch._C._set_mkldnn_enabled(_enabled)
return orig_flags
@contextmanager
def flags(enabled=False):
Reported by Pylint.
Line: 11
Column: 19
return torch._C.has_mkldnn
def set_flags(_enabled):
orig_flags = (torch._C._get_mkldnn_enabled(),)
torch._C._set_mkldnn_enabled(_enabled)
return orig_flags
@contextmanager
def flags(enabled=False):
Reported by Pylint.
Line: 12
Column: 5
def set_flags(_enabled):
orig_flags = (torch._C._get_mkldnn_enabled(),)
torch._C._set_mkldnn_enabled(_enabled)
return orig_flags
@contextmanager
def flags(enabled=False):
with __allow_nonbracketed_mutation():
Reported by Pylint.
Line: 12
Column: 5
def set_flags(_enabled):
orig_flags = (torch._C._get_mkldnn_enabled(),)
torch._C._set_mkldnn_enabled(_enabled)
return orig_flags
@contextmanager
def flags(enabled=False):
with __allow_nonbracketed_mutation():
Reported by Pylint.
Line: 26
Column: 5
set_flags(orig_flags[0])
class MkldnnModule(PropModule):
def __init__(self, m, name):
super(MkldnnModule, self).__init__(m, name)
enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
# Cool stuff from torch/backends/cudnn/__init__.py and
Reported by Pylint.
Line: 29
Column: 27
def __init__(self, m, name):
super(MkldnnModule, self).__init__(m, name)
enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
# Cool stuff from torch/backends/cudnn/__init__.py and
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
Reported by Pylint.
Line: 29
Column: 57
def __init__(self, m, name):
super(MkldnnModule, self).__init__(m, name)
enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
# Cool stuff from torch/backends/cudnn/__init__.py and
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
Reported by Pylint.
Line: 29
Column: 57
def __init__(self, m, name):
super(MkldnnModule, self).__init__(m, name)
enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
# Cool stuff from torch/backends/cudnn/__init__.py and
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
Reported by Pylint.
Line: 29
Column: 27
def __init__(self, m, name):
super(MkldnnModule, self).__init__(m, name)
enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
# Cool stuff from torch/backends/cudnn/__init__.py and
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
Reported by Pylint.
torch/distributed/elastic/multiprocessing/errors/__init__.py
22 issues
Line: 64
Column: 1
from torch.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
log = get_logger()
Reported by Pylint.
Line: 65
Column: 1
from torch.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
log = get_logger()
JSON = Dict
Reported by Pylint.
Line: 150
Column: 20
def signal_name(self) -> str:
if self.exitcode < 0:
return signal.Signals(-self.exitcode).name
else:
return _NOT_AVAILABLE
def timestamp_isoformat(self):
"""
Reported by Pylint.
Line: 111
Column: 21
try:
with open(self.error_file, "r") as fp:
self.error_file_data = json.load(fp)
log.info(
f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}"
)
self.message, self.timestamp = self._get_error_data(
self.error_file_data
)
Reported by Pylint.
Line: 118
Column: 17
self.error_file_data
)
except Exception:
log.exception(f"Failed to parse reply file: {self.error_file}")
raise
else:
self._set_no_reply_file()
# make up an informative message if not already present
Reported by Pylint.
Line: 234
Column: 20
def format_msg(self, boarder_delim="*", section_delim="="):
title = f" {self.name} FAILED "
root_rank, root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()):
Reported by Pylint.
Line: 75
Column: 1
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
T = TypeVar("T")
@dataclass
class ProcessFailure:
"""
Reported by Pylint.
Line: 109
Column: 52
self.error_file_data = _EMPTY_ERROR_DATA
if os.path.isfile(self.error_file):
try:
with open(self.error_file, "r") as fp:
self.error_file_data = json.load(fp)
log.info(
f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}"
)
self.message, self.timestamp = self._get_error_data(
Reported by Pylint.
Line: 112
Column: 1
with open(self.error_file, "r") as fp:
self.error_file_data = json.load(fp)
log.info(
f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}"
)
self.message, self.timestamp = self._get_error_data(
self.error_file_data
)
except Exception:
Reported by Pylint.
Line: 134
Column: 5
else:
self.message = f"Process failed with exitcode {self.exitcode}"
def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
message = error_file_data["message"]
if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0))
else:
timestamp = int(message["extraInfo"]["timestamp"])
Reported by Pylint.
torch/distributions/gamma.py
22 issues
Line: 10
Column: 12
def _standard_gamma(concentration):
return torch._standard_gamma(concentration)
class Gamma(ExponentialFamily):
r"""
Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
Reported by Pylint.
Line: 45
Column: 27
def __init__(self, concentration, rate, validate_args=None):
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, Number) and isinstance(rate, Number):
batch_shape = torch.Size()
else:
batch_shape = self.concentration.size()
super(Gamma, self).__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
Reported by Pylint.
Line: 52
Column: 23
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Gamma, _instance)
batch_shape = torch.Size(batch_shape)
new.concentration = self.concentration.expand(batch_shape)
new.rate = self.rate.expand(batch_shape)
super(Gamma, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new
Reported by Pylint.
Line: 59
Column: 36
new._validate_args = self._validate_args
return new
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph
return value
Reported by Pylint.
Line: 62
Column: 35
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph
return value
def log_prob(self, value):
value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
Reported by Pylint.
Line: 66
Column: 17
return value
def log_prob(self, value):
value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
self._validate_sample(value)
return (self.concentration * torch.log(self.rate) +
(self.concentration - 1) * torch.log(value) -
self.rate * value - torch.lgamma(self.concentration))
Reported by Pylint.
Line: 69
Column: 38
value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
self._validate_sample(value)
return (self.concentration * torch.log(self.rate) +
(self.concentration - 1) * torch.log(value) -
self.rate * value - torch.lgamma(self.concentration))
def entropy(self):
return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
Reported by Pylint.
Line: 70
Column: 44
if self._validate_args:
self._validate_sample(value)
return (self.concentration * torch.log(self.rate) +
(self.concentration - 1) * torch.log(value) -
self.rate * value - torch.lgamma(self.concentration))
def entropy(self):
return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
(1.0 - self.concentration) * torch.digamma(self.concentration))
Reported by Pylint.
Line: 71
Column: 37
self._validate_sample(value)
return (self.concentration * torch.log(self.rate) +
(self.concentration - 1) * torch.log(value) -
self.rate * value - torch.lgamma(self.concentration))
def entropy(self):
return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
(1.0 - self.concentration) * torch.digamma(self.concentration))
Reported by Pylint.
Line: 74
Column: 61
self.rate * value - torch.lgamma(self.concentration))
def entropy(self):
return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
(1.0 - self.concentration) * torch.digamma(self.concentration))
@property
def _natural_params(self):
return (self.concentration - 1, -self.rate)
Reported by Pylint.
torch/distributions/bernoulli.py
22 issues
Line: 45
Column: 27
self.logits, = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = torch.Size()
else:
batch_shape = self._param.size()
super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
Reported by Pylint.
Line: 52
Column: 23
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Bernoulli, _instance)
batch_shape = torch.Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
Reported by Pylint.
Line: 75
Column: 5
return self.probs * (1 - self.probs)
@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)
@lazy_property
def probs(self):
return logits_to_probs(self.logits, is_binary=True)
Reported by Pylint.
Line: 79
Column: 5
return probs_to_logits(self.probs, is_binary=True)
@lazy_property
def probs(self):
return logits_to_probs(self.logits, is_binary=True)
@property
def param_shape(self):
return self._param.size()
Reported by Pylint.
Line: 86
Column: 35
def param_shape(self):
return self._param.size()
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
with torch.no_grad():
return torch.bernoulli(self.probs.expand(shape))
def log_prob(self, value):
Reported by Pylint.
Line: 89
Column: 20
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
with torch.no_grad():
return torch.bernoulli(self.probs.expand(shape))
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
Reported by Pylint.
Line: 101
Column: 18
return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none')
def enumerate_support(self, expand=True):
values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
values = values.view((-1,) + (1,) * len(self._batch_shape))
if expand:
values = values.expand((-1,) + self._batch_shape)
return values
Reported by Pylint.
Line: 109
Column: 17
@property
def _natural_params(self):
return (torch.log(self.probs / (1 - self.probs)), )
def _log_normalizer(self, x):
return torch.log(1 + torch.exp(x))
Reported by Pylint.
Line: 112
Column: 30
return (torch.log(self.probs / (1 - self.probs)), )
def _log_normalizer(self, x):
return torch.log(1 + torch.exp(x))
Reported by Pylint.
Line: 112
Column: 16
return (torch.log(self.probs / (1 - self.probs)), )
def _log_normalizer(self, x):
return torch.log(1 + torch.exp(x))
Reported by Pylint.
torch/distributions/one_hot_categorical.py
22 issues
Line: 48
Column: 23
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(OneHotCategorical, _instance)
batch_shape = torch.Size(batch_shape)
new._categorical = self._categorical.expand(batch_shape)
super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
Reported by Pylint.
Line: 81
Column: 35
def param_shape(self):
return self._categorical.param_shape
def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
probs = self._categorical.probs
num_events = self._categorical._num_events
indices = self._categorical.sample(sample_shape)
return torch.nn.functional.one_hot(indices, num_events).to(probs)
Reported by Pylint.
Line: 82
Column: 24
return self._categorical.param_shape
def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
probs = self._categorical.probs
num_events = self._categorical._num_events
indices = self._categorical.sample(sample_shape)
return torch.nn.functional.one_hot(indices, num_events).to(probs)
Reported by Pylint.
Line: 99
Column: 18
def enumerate_support(self, expand=True):
n = self.event_shape[0]
values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = values.expand((n,) + self.batch_shape + (n,))
return values
Reported by Pylint.
Line: 115
Column: 36
"""
has_rsample = True
def rsample(self, sample_shape=torch.Size()):
samples = self.sample(sample_shape)
probs = self._categorical.probs # cached via @lazy_property
return samples + (probs - probs.detach())
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.distribution import Distribution
class OneHotCategorical(Distribution):
r"""
Creates a one-hot categorical distribution parameterized by :attr:`probs` or
:attr:`logits`.
Samples are one-hot coded vectors of size ``probs.size(-1)``.
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.distribution import Distribution
class OneHotCategorical(Distribution):
r"""
Creates a one-hot categorical distribution parameterized by :attr:`probs` or
:attr:`logits`.
Samples are one-hot coded vectors of size ``probs.size(-1)``.
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.distribution import Distribution
class OneHotCategorical(Distribution):
r"""
Creates a one-hot categorical distribution parameterized by :attr:`probs` or
:attr:`logits`.
Samples are one-hot coded vectors of size ``probs.size(-1)``.
Reported by Pylint.
Line: 49
Column: 9
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(OneHotCategorical, _instance)
batch_shape = torch.Size(batch_shape)
new._categorical = self._categorical.expand(batch_shape)
super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
def _new(self, *args, **kwargs):
Reported by Pylint.
Line: 51
Column: 9
batch_shape = torch.Size(batch_shape)
new._categorical = self._categorical.expand(batch_shape)
super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)
Reported by Pylint.