
[DOCS][NFC] Fix doc formatting problems #4195


Merged: 2 commits, Jun 24, 2024
127 changes: 64 additions & 63 deletions python/triton/language/core.py
@@ -1190,7 +1190,12 @@ def num_programs(axis, _builder=None):

@builtin
def arange(start, end, _builder=None):
-    f"""
+    start = _constexpr_to_value(start)
+    end = _constexpr_to_value(end)
+    return semantic.arange(start, end, _builder)
+
+
+arange.__doc__ = f"""
    Returns contiguous values within the half-open interval :code:`[start,
    end)`. :code:`end - start` must be less than or equal to
    :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}`
@@ -1200,10 +1205,7 @@ def arange(start, end, _builder=None):
    :param end: End of the interval. Must be a power of two greater than
        :code:`start`.
    :type end: int32
-    """
-    start = _constexpr_to_value(start)
-    end = _constexpr_to_value(end)
-    return semantic.arange(start, end, _builder)
+    """


def _shape_check_impl(shape):
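For reviewers cross-checking the relocated docstring against behavior: a minimal usage sketch of `tl.arange` (hypothetical kernel, not part of this diff; assumes `BLOCK` is a power of two no larger than `TRITON_MAX_TENSOR_NUMEL`):

```python
import triton
import triton.language as tl

@triton.jit
def iota_kernel(out_ptr, BLOCK: tl.constexpr):
    # Contiguous int32 values in the half-open interval [0, BLOCK).
    offsets = tl.arange(0, BLOCK)
    tl.store(out_ptr + offsets, offsets)
```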
@@ -1582,9 +1584,8 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
(3) If `pointer` is a block pointer defined by `make_block_ptr`, a
tensor is loaded. In this case:

-    - `mask` and `other` must be None, and
-    - `boundary_check` and `padding_option` can be specified to control
-      the behavior of out-of-bound access.
+    - `mask` and `other` must be `None`, and
+    - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.

:param pointer: Pointer to the data to be loaded
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
@@ -1599,7 +1600,7 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see
-    [cache operator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators) for more details.
+    `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional
:param volatile: changes volatile option in NVIDIA PTX
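To make the cache-modifier semantics above concrete, a hedged sketch of the ordinary (non-block-pointer) form of `tl.load`; `in_ptr`, `out_ptr`, and `N` are illustrative names, not from this PR:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(in_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    # Masked-off lanes read `other` instead of touching memory; ".cg" asks
    # PTX to cache the access at global level (L2 and below, not L1).
    x = tl.load(in_ptr + offsets, mask=mask, other=0.0, cache_modifier=".cg")
    tl.store(out_ptr + offsets, x, mask=mask)
```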
@@ -1680,7 +1681,7 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for
cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt"
-    stands for cache write-through, see [cache operator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators) for more details.
+    stands for cache write-through, see `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
"""
@@ -2227,7 +2228,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=Fals
.. highlight:: python
.. code-block:: python

tl.static_print(f"{BLOCK_SIZE=}")
tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}")
'''
pass
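The rewritten line avoids the `f"{X=}"` self-documenting form; spelled out, the same compile-time print looks like this (hypothetical kernel for illustration):

```python
import triton
import triton.language as tl

@triton.jit
def kernel(out_ptr, BLOCK_SIZE: tl.constexpr):
    # Evaluated once per specialization at compile time, not per element.
    tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}")
    offsets = tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offsets, offsets)
```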

@@ -2359,66 +2360,66 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
cost you anything if you don't use it.

Example using
-    [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html)
+    `PTX <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html>`_
assembly:

.. highlight:: python
.. code-block:: python

    @triton.jit
    def kernel(A, B, C, D, BLOCK: tl.constexpr):
        a = tl.load(A + tl.arange(0, BLOCK))  # uint8 tensor
        b = tl.load(B + tl.arange(0, BLOCK))  # float32 tensor

        # For each (a,b) in zip(a,b), perform the following:
        # - Let ai be `a` converted to int32.
        # - Let af be `a` converted to float.
        # - Let m be the max of ai and b.
        # - Return ai and m.
        # Do the above 4 elements at a time.
        (c, d) = tl.inline_asm_elementwise(
            asm="""
            {
                // Unpack `a` into `ai`.
                .reg .b8 tmp<4>;
                mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
                cvt.u32.u8 $0, tmp0;
                cvt.u32.u8 $1, tmp1;
                cvt.u32.u8 $2, tmp2;
                cvt.u32.u8 $3, tmp3;
            }
            // Convert `ai` to float.
            cvt.rn.f32.s32 $4, $0;
            cvt.rn.f32.s32 $5, $1;
            cvt.rn.f32.s32 $6, $2;
            cvt.rn.f32.s32 $7, $3;
            // Take max of `ai` and `b`.
            max.f32 $4, $4, $9;
            max.f32 $5, $5, $10;
            max.f32 $6, $6, $11;
            max.f32 $7, $7, $12;
            """,
            constraints=(
                # 8 output registers, namely
                #   $0=ai0, $1=ai1, $2=ai2, $3=ai3,
                #   $4=m0,  $5=m1,  $6=m2,  $7=m3.
                "=r,=r,=r,=r,=r,=r,=r,=r,"
                # 5 input registers, namely
                #   $8=ai,
                #   $9=b0, $10=b1, $11=b2, $12=b3.
                # The four elements from `a` are all packed into one register.
                "r,r,r,r,r"),
            args=[a, b],
            dtype=(tl.int32, tl.float32),
            is_pure=True,
            pack=4,
        )
        tl.store(C + tl.arange(0, BLOCK), c)
        tl.store(D + tl.arange(0, BLOCK), d)

:param asm: assembly to run. Must match target's assembly format.
:param constraints: asm constraints in
-        [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string)
+        `LLVM format <https://llvm.org/docs/LangRef.html#inline-asm-constraint-string>`_
:param args: the input tensors, whose values are passed to the asm block
:param dtype: the element type(s) of the returned tensor(s)
:param is_pure: if true, the compiler assumes the asm block has no side-effects
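Beyond the packed example above, a smaller hedged sketch may help readers map constraints to registers; it assumes an int32 input and a single identity `mov` (illustrative only, not part of this PR):

```python
import triton
import triton.language as tl

@triton.jit
def passthrough_kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))  # int32 tensor
    # One output register ($0) and one input register ($1); pack=1 means the
    # asm block handles a single element per invocation.
    y = tl.inline_asm_elementwise(
        asm="mov.b32 $0, $1;",
        constraints="=r,r",
        args=[x],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(Y + tl.arange(0, BLOCK), y)
```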