The following issues were found:
torch/csrc/jit/codegen/cuda/executor_utils.cpp
9 issues
Line: 325
Column: 29
CWE codes:
807
20
Suggestion:
Check environment variables carefully before using them
"--std=c++14", compute.c_str(), "-default-device"};
#endif
const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA");
// int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0;
if (disable_fma && atoi(disable_fma)) {
#ifdef __HIP_PLATFORM_HCC__
TORCH_WARN_ONCE(
"PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring");
Reported by FlawFinder.
Line: 336
Column: 33
CWE codes:
807
20
Suggestion:
Check environment variables carefully before using them
#endif
}
const char* ptxas_opt_level = getenv("PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t jit_opt_level;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<CUjit_option> options;
Reported by FlawFinder.
Line: 415
Column: 28
CWE codes:
807
20
Suggestion:
Check environment variables carefully before using them
// TODO: We do go through different code path, should investigate whether this
// has an impact on generated binary.
const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN");
#ifndef __HIP_PLATFORM_HCC__
if (prefix_env) {
#if CUDA_VERSION >= 11010
TORCH_CHECK(
!compile_to_sass,
Reported by FlawFinder.
Line: 327
Column: 22
CWE codes:
190
Suggestion:
If source untrusted, check both minimum and maximum, even if the input had no minus sign (large numbers can roll over into negative number; consider saving to an unsigned value if that is intended)
const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA");
// int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0;
if (disable_fma && atoi(disable_fma)) {
#ifdef __HIP_PLATFORM_HCC__
TORCH_WARN_ONCE(
"PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring");
#else
args.push_back("--fmad=false");
Reported by FlawFinder.
Line: 346
Column: 15
CWE codes:
190
Suggestion:
If source untrusted, check both minimum and maximum, even if the input had no minus sign (large numbers can roll over into negative number; consider saving to an unsigned value if that is intended)
std::vector<void*> option_vals;
if (ptxas_opt_level) {
int val = atoi(ptxas_opt_level);
if (val <= 4 && val >= 0) {
jit_opt_level = static_cast<uint32_t>(val);
options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
option_vals.emplace_back(&jit_opt_level);
} else {
Reported by FlawFinder.
Line: 189
Column: 63
CWE codes:
126
Suggestion:
This function is often discouraged by most C++ coding standards in favor of its safer alternatives provided since C++14. Consider using a form of this function that checks the second iterator before potentially overflowing it
for (const auto i : c10::irange(inputs.size())) {
const IValue& arg = inputs[i];
const Val* param = fusion->inputs()[i];
mismatch = !validateKernelArg(arg, param, device, msg) || mismatch;
}
TORCH_INTERNAL_ASSERT(
!mismatch, "Found one or more invalid arguments: ", msg.str());
}
Reported by FlawFinder.
Line: 192
Column: 8
CWE codes:
126
Suggestion:
This function is often discouraged by most C++ coding standards in favor of its safer alternatives provided since C++14. Consider using a form of this function that checks the second iterator before potentially overflowing it
mismatch = !validateKernelArg(arg, param, device, msg) || mismatch;
}
TORCH_INTERNAL_ASSERT(
!mismatch, "Found one or more invalid arguments: ", msg.str());
}
void validateKernelOutputs(
Fusion* fusion,
const std::vector<at::Tensor>& outputs,
Reported by FlawFinder.
Line: 214
Column: 63
CWE codes:
126
Suggestion:
This function is often discouraged by most C++ coding standards in favor of its safer alternatives provided since C++14. Consider using a form of this function that checks the second iterator before potentially overflowing it
for (const auto i : c10::irange(outputs.size())) {
const at::Tensor& arg = outputs[i];
const Val* param = fusion->outputs()[i];
mismatch = !validateKernelArg(arg, param, device, msg) || mismatch;
}
TORCH_INTERNAL_ASSERT(
!mismatch, "Found one or more invalid arguments: ", msg.str());
}
Reported by FlawFinder.
Line: 217
Column: 8
CWE codes:
126
Suggestion:
This function is often discouraged by most C++ coding standards in favor of its safer alternatives provided since C++14. Consider using a form of this function that checks the second iterator before potentially overflowing it
mismatch = !validateKernelArg(arg, param, device, msg) || mismatch;
}
TORCH_INTERNAL_ASSERT(
!mismatch, "Found one or more invalid arguments: ", msg.str());
}
StatefulExpressionEvaluator statefulBindInputs(
const at::ArrayRef<IValue>& aten_inputs,
Fusion* fusion,
Reported by FlawFinder.
torch/distributed/rpc/rref_proxy.py
9 issues
Line: 3
Column: 1
from functools import partial
from . import functions
import torch
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
Reported by Pylint.
Line: 6
Column: 1
from . import functions
import torch
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
@functions.async_execution
Reported by Pylint.
Line: 18
Column: 17
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
# Since rref._get_type can potentially issue an RPC, it should respect the
# passed in timeout here.
rref_type = rref._get_type(timeout=timeout)
_invoke_func = _local_invoke
# Bypass ScriptModules when checking for async function attribute.
bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
rref_type, torch._C.ScriptModule
Reported by Pylint.
Line: 23
Column: 20
_invoke_func = _local_invoke
# Bypass ScriptModules when checking for async function attribute.
bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
rref_type, torch._C.ScriptModule
)
if not bypass_type:
func = getattr(rref_type, func_name)
if hasattr(func, "_wrapped_async_rpc_function"):
_invoke_func = _local_invoke_async_execution
Reported by Pylint.
Line: 1
Column: 1
from functools import partial
from . import functions
import torch
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
Reported by Pylint.
Line: 5
Column: 1
from . import functions
import torch
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
Reported by Pylint.
Line: 39
Column: 1
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
self.rref = rref
self.rpc_api = rpc_api
self.rpc_timeout = timeout
Reported by Pylint.
Line: 39
Column: 1
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
self.rref = rref
self.rpc_api = rpc_api
self.rpc_timeout = timeout
Reported by Pylint.
Line: 23
Column: 20
_invoke_func = _local_invoke
# Bypass ScriptModules when checking for async function attribute.
bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
rref_type, torch._C.ScriptModule
)
if not bypass_type:
func = getattr(rref_type, func_name)
if hasattr(func, "_wrapped_async_rpc_function"):
_invoke_func = _local_invoke_async_execution
Reported by Pylint.
test/package/package_a/test_all_leaf_modules_tracer.py
9 issues
Line: 1
Column: 1
from torch.fx import Tracer
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 5
Column: 30
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 5
Column: 33
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 1
Column: 1
from torch.fx import Tracer
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 4
Column: 1
from torch.fx import Tracer
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 4
Column: 1
from torch.fx import Tracer
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 5
Column: 5
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 5
Column: 5
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
Line: 5
Column: 5
class TestAllLeafModulesTracer(Tracer):
def is_leaf_module(self, m, qualname):
return True
Reported by Pylint.
torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py
9 issues
Line: 12
Column: 1
from typing import Optional, Tuple, cast
import urllib3.exceptions # type: ignore[import]
from etcd import Client as EtcdClient # type: ignore[import]
from etcd import (
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
Reported by Pylint.
Line: 13
Column: 1
import urllib3.exceptions # type: ignore[import]
from etcd import Client as EtcdClient # type: ignore[import]
from etcd import (
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
EtcdResult,
Reported by Pylint.
Line: 22
Column: 1
)
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
Reported by Pylint.
Line: 23
Column: 1
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
Reported by Pylint.
Line: 24
Column: 1
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
"""Represents an etcd-based rendezvous backend.
Reported by Pylint.
Line: 25
Column: 1
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
"""Represents an etcd-based rendezvous backend.
Reported by Pylint.
Line: 1
Column: 1
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
from base64 import b64decode, b64encode
from typing import Optional, Tuple, cast
Reported by Pylint.
Line: 131
Column: 5
tmp = *self._decode_state(result), True
return tmp
def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
base64_state = result.value.encode()
try:
state = b64decode(base64_state)
except binascii.Error as exc:
Reported by Pylint.
Line: 154
Column: 8
# The communication protocol
protocol = params.get("protocol", "http").strip().lower()
if protocol != "http" and protocol != "https":
raise ValueError("The protocol must be HTTP or HTTPS.")
# The SSL client certificate
ssl_cert = params.get("ssl_cert")
if ssl_cert:
Reported by Pylint.
test/test_determination.py
9 issues
Line: 5
Column: 1
import unittest
import run_test
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
Reported by Pylint.
Line: 1
Column: 1
import os
import unittest
import run_test
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
Reported by Pylint.
Line: 5
Column: 1
import unittest
import run_test
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
Reported by Pylint.
Line: 8
Column: 1
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
class DeterminationTest(unittest.TestCase):
# Test determination on a subset of tests
Reported by Pylint.
Line: 8
Column: 1
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
class DeterminationTest(unittest.TestCase):
# Test determination on a subset of tests
Reported by Pylint.
Line: 8
Column: 1
from torch.testing._internal.common_utils import run_tests
class DummyOptions(object):
verbose = False
class DeterminationTest(unittest.TestCase):
# Test determination on a subset of tests
Reported by Pylint.
Line: 12
Column: 1
verbose = False
class DeterminationTest(unittest.TestCase):
# Test determination on a subset of tests
TESTS = [
"test_nn",
"test_jit_profiling",
"test_jit",
Reported by Pylint.
Line: 29
Column: 5
]
@classmethod
def determined_tests(cls, changed_files):
changed_files = [os.path.normpath(path) for path in changed_files]
return [
test
for test in cls.TESTS
if run_test.determine_target(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions())
Reported by Pylint.
Line: 34
Column: 1
return [
test
for test in cls.TESTS
if run_test.determine_target(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions())
]
def test_config_change_only(self):
"""CI configs trigger all tests"""
self.assertEqual(
Reported by Pylint.
torch/distributions/logistic_normal.py
9 issues
Line: 7
Column: 1
from torch.distributions.transforms import StickBreakingTransform
class LogisticNormal(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`StickBreakingTransform` such that::
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.transforms import StickBreakingTransform
class LogisticNormal(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`StickBreakingTransform` such that::
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.transforms import StickBreakingTransform
class LogisticNormal(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`StickBreakingTransform` such that::
Reported by Pylint.
Line: 7
Column: 1
from torch.distributions.transforms import StickBreakingTransform
class LogisticNormal(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`StickBreakingTransform` such that::
Reported by Pylint.
Line: 1
Column: 1
from torch.distributions import constraints
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import StickBreakingTransform
class LogisticNormal(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
Reported by Pylint.
Line: 37
Column: 9
base_dist = Normal(loc, scale, validate_args=validate_args)
if not base_dist.batch_shape:
base_dist = base_dist.expand([1])
super(LogisticNormal, self).__init__(base_dist,
StickBreakingTransform(),
validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogisticNormal, _instance)
Reported by Pylint.
Line: 43
Column: 16
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogisticNormal, _instance)
return super(LogisticNormal, self).expand(batch_shape, _instance=new)
@property
def loc(self):
return self.base_dist.base_dist.loc
Reported by Pylint.
Line: 46
Column: 5
return super(LogisticNormal, self).expand(batch_shape, _instance=new)
@property
def loc(self):
return self.base_dist.base_dist.loc
@property
def scale(self):
return self.base_dist.base_dist.scale
Reported by Pylint.
Line: 50
Column: 5
return self.base_dist.base_dist.loc
@property
def scale(self):
return self.base_dist.base_dist.scale
Reported by Pylint.
torch/csrc/jit/tensorexpr/scripts/bisect.py
9 issues
Line: 9
Suggestion:
https://bandit.readthedocs.io/en/latest/plugins/b602_subprocess_popen_with_shell_equals_true.html
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
capture_output=True,
encoding="utf-8",
)
print(p.stdout)
f = "INTERNAL ASSERT FAILED"
Reported by Bandit.
Line: 68
Column: 5
if __name__ == "__main__":
bisect()
Reported by Pylint.
Line: 7
Column: 9
def test(cmd, limit):
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
capture_output=True,
encoding="utf-8",
)
Reported by Pylint.
Line: 1
Suggestion:
https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b404-import-subprocess
import subprocess
import click
def test(cmd, limit):
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
Reported by Bandit.
Line: 1
Column: 1
import subprocess
import click
def test(cmd, limit):
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
Reported by Pylint.
Line: 5
Column: 1
import click
def test(cmd, limit):
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
capture_output=True,
Reported by Pylint.
Line: 7
Column: 5
def test(cmd, limit):
print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
p = subprocess.run(
f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
shell=True,
capture_output=True,
encoding="utf-8",
)
Reported by Pylint.
Line: 14
Column: 5
encoding="utf-8",
)
print(p.stdout)
f = "INTERNAL ASSERT FAILED"
if f in p.stdout or f in p.stderr:
print("skip")
return -1
if p.returncode == 0:
print("good")
Reported by Pylint.
Line: 27
Column: 1
@click.command()
@click.option("--cmd")
def bisect(cmd):
last_good = 0
first_bad = 10000
skips = set()
# Test if there are any unskipped commits in (last_good, first_bad)
Reported by Pylint.
torch/distributed/elastic/utils/api.py
9 issues
Line: 35
Column: 17
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
Reported by Pylint.
Line: 41
Column: 9
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
raise RuntimeError("Failed to create a socket")
class macros:
Reported by Pylint.
Line: 1
Column: 1
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
Reported by Pylint.
Line: 30
Column: 1
return value
def get_socket_with_port() -> socket.socket:
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
Reported by Pylint.
Line: 36
Column: 9
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
Reported by Pylint.
Line: 41
Column: 9
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
raise RuntimeError("Failed to create a socket")
class macros:
Reported by Pylint.
Line: 46
Column: 1
raise RuntimeError("Failed to create a socket")
class macros:
"""
Defines simple macros for caffe2.distributed.launch cmd args substitution
"""
local_rank = "${local_rank}"
Reported by Pylint.
Line: 46
Column: 1
raise RuntimeError("Failed to create a socket")
class macros:
"""
Defines simple macros for caffe2.distributed.launch cmd args substitution
"""
local_rank = "${local_rank}"
Reported by Pylint.
Line: 54
Column: 5
local_rank = "${local_rank}"
@staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]:
args_sub = []
for arg in args:
if isinstance(arg, str):
sub = Template(arg).safe_substitute(local_rank=local_rank)
args_sub.append(sub)
Reported by Pylint.
test/onnx/test_models_onnxruntime.py
9 issues
Line: 2
Column: 1
import unittest
import onnxruntime # noqa: F401
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
Reported by Pylint.
Line: 6
Column: 1
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
Reported by Pylint.
Line: 2
Column: 1
import unittest
import onnxruntime # noqa: F401
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
Reported by Pylint.
Line: 1
Column: 1
import unittest
import onnxruntime # noqa: F401
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
Reported by Pylint.
Line: 6
Column: 1
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
Reported by Pylint.
Line: 9
Column: 1
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
for opset_version in opset_versions:
self.opset_version = opset_version
run_model_test(self, model, False,
Reported by Pylint.
Line: 9
Column: 1
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
for opset_version in opset_versions:
self.opset_version = opset_version
run_model_test(self, model, False,
Reported by Pylint.
Line: 9
Column: 1
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
for opset_version in opset_versions:
self.opset_version = opset_version
run_model_test(self, model, False,
Reported by Pylint.
Line: 34
Column: 1
# model tests for scripting with new JIT APIs and shape inference
TestModels_new_jit_API = type(str("TestModels_new_jit_API"),
(unittest.TestCase,),
dict(TestModels.__dict__,
exportTest=exportTest,
is_script_test_enabled=True,
onnx_shape_inference=True))
Reported by Pylint.
test/test_complex.py
9 issues
Line: 1
Column: 1
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_utils import TestCase, run_tests
devices = (torch.device('cpu'), torch.device('cuda:0'))
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
Reported by Pylint.
Line: 2
Column: 1
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_utils import TestCase, run_tests
devices = (torch.device('cpu'), torch.device('cuda:0'))
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
Reported by Pylint.
Line: 3
Column: 1
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_utils import TestCase, run_tests
devices = (torch.device('cpu'), torch.device('cuda:0'))
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
Reported by Pylint.
Line: 1
Column: 1
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_utils import TestCase, run_tests
devices = (torch.device('cpu'), torch.device('cuda:0'))
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
Reported by Pylint.
Line: 7
Column: 1
devices = (torch.device('cpu'), torch.device('cuda:0'))
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
# test that the complex float tensor has expected values and
# there's no garbage value in the resultant list
self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
Reported by Pylint.
Line: 9
Column: 5
class TestComplexTensor(TestCase):
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_to_list(self, device, dtype):
# test that the complex float tensor has expected values and
# there's no garbage value in the resultant list
self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
@dtypes(torch.float32, torch.float64)
Reported by Pylint.
Line: 12
Column: 1
def test_to_list(self, device, dtype):
# test that the complex float tensor has expected values and
# there's no garbage value in the resultant list
self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
@dtypes(torch.float32, torch.float64)
def test_dtype_inference(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/36834
default_dtype = torch.get_default_dtype()
Reported by Pylint.
Line: 15
Column: 5
self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
@dtypes(torch.float32, torch.float64)
def test_dtype_inference(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/36834
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
x = torch.tensor([3., 3. + 5.j], device=device)
torch.set_default_dtype(default_dtype)
Reported by Pylint.
Line: 19
Column: 9
# issue: https://github.com/pytorch/pytorch/issues/36834
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
x = torch.tensor([3., 3. + 5.j], device=device)
torch.set_default_dtype(default_dtype)
self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
instantiate_device_type_tests(TestComplexTensor, globals())
Reported by Pylint.