Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
VectorType vecCast = mlir::cast<VectorType>(lhs.getType());
IntType integralTy =
getSIntNTy(getCIRIntOrFloatBitWidth(vecCast.getElementType()));
VectorType integralVecTy =
VectorType::get(context, integralTy, vecCast.getSize());
VectorType integralVecTy = cir::VectorType::get(integralTy, vecCast.getSize());
return cir::VecCmpOp::create(*this, loc, integralVecTy, kind, lhs, rhs);
}

Expand Down
18 changes: 11 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,9 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
]> {
let summary = "CIR vector type";
let description = [{
The `!cir.vector` type represents a fixed-size, one-dimensional vector.
It takes two parameters: the element type and the number of elements.
The `!cir.vector` type represents a one-dimensional vector.
It takes three parameters: the element type, the number of elements and the
scalability flag (optional, defaults to `false`).

Syntax:

Expand All @@ -444,19 +445,21 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
}];

let parameters = (ins
CIR_VectorElementType:$elementType,
"uint64_t":$size
CIR_VectorElementType:$element_type,
"uint64_t":$size,
OptionalParameter<"bool">:$is_scalable
);

let assemblyFormat = [{
`<` $size `x` $elementType `>`
`<` $size `x` $element_type `>`
}];

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$elementType, "uint64_t":$size
"mlir::Type":$element_type, "uint64_t":$size, CArg<"bool",
"false">:$isScalable
), [{
return $_get(elementType.getContext(), elementType, size);
return $_get(element_type.getContext(), element_type, size, isScalable);
}]>,
];

Expand All @@ -467,6 +470,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
}];

let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
9 changes: 7 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
cir::IntType getUInt32Ty() { return typeCache.uInt32Ty; }
cir::IntType getUInt64Ty() { return typeCache.uInt64Ty; }

cir::FP16Type getFp16Ty() { return typeCache.fP16Ty; }
cir::BF16Type getBfloat6Ty() { return typeCache.bFloat16Ty; }
cir::SingleType getSingleTy() { return typeCache.floatTy; }
cir::DoubleType getDoubleTy() { return typeCache.doubleTy; }

cir::ConstantOp getConstInt(mlir::Location loc, llvm::APSInt intVal);

cir::ConstantOp getConstInt(mlir::Location loc, llvm::APInt intVal);
Expand Down Expand Up @@ -628,8 +633,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
createVecShuffle(mlir::Location loc, mlir::Value vec1, mlir::Value vec2,
llvm::ArrayRef<mlir::Attribute> maskAttrs) {
auto vecType = mlir::cast<cir::VectorType>(vec1.getType());
auto resultTy = cir::VectorType::get(getContext(), vecType.getElementType(),
maskAttrs.size());
auto resultTy =
cir::VectorType::get(vecType.getElementType(), maskAttrs.size());
return cir::VecShuffleOp::create(*this, loc, resultTy, vec1, vec2,
getArrayAttr(maskAttrs));
}
Expand Down
48 changes: 40 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//

#include "CIRGenBuilder.h"
#include "CIRGenFunction.h"
#include "clang/CIR/MissingFeatures.h"

Expand All @@ -30,6 +31,27 @@ using namespace clang;
using namespace clang::CIRGen;
using namespace llvm;

template <typename... Operands>
static mlir::Value emitIntrinsicCallOp(CIRGenBuilderTy &builder,
mlir::Location loc, const StringRef str,
const mlir::Type &resTy,
Operands &&...op) {
return cir::LLVMIntrinsicCallOp::create(builder, loc,
builder.getStringAttr(str), resTy,
std::forward<Operands>(op)...)
.getResult();
}

// Generate vscale * scalingFactor
static mlir::Value genVscaleTimesFactor(mlir::Location loc,
CIRGenBuilderTy builder,
mlir::Type cirTy,
int32_t scalingFactor) {
auto vscale = emitIntrinsicCallOp(builder, loc, "vscale", cirTy);
return builder.createNUWAMul(loc, vscale,
builder.getUInt64(scalingFactor, loc));
}

std::optional<mlir::Value>
CIRGenFunction::emitAArch64SVEBuiltinExpr(unsigned builtinID,
const CallExpr *expr) {
Expand All @@ -47,6 +69,8 @@ CIRGenFunction::emitAArch64SVEBuiltinExpr(unsigned builtinID,
default:
return std::nullopt;

mlir::Location loc = getLoc(expr->getExprLoc());

case SVE::BI__builtin_sve_svreinterpret_b:
case SVE::BI__builtin_sve_svreinterpret_c:
case SVE::BI__builtin_sve_svpsel_lane_b8:
Expand Down Expand Up @@ -101,18 +125,26 @@ CIRGenFunction::emitAArch64SVEBuiltinExpr(unsigned builtinID,
case SVE::BI__builtin_sve_svdupq_n_s32:
case SVE::BI__builtin_sve_svpfalse_b:
case SVE::BI__builtin_sve_svpfalse_c:
case SVE::BI__builtin_sve_svlen_bf16:
case SVE::BI__builtin_sve_svlen_f16:
case SVE::BI__builtin_sve_svlen_f32:
case SVE::BI__builtin_sve_svlen_f64:
case SVE::BI__builtin_sve_svlen_s8:
case SVE::BI__builtin_sve_svlen_s16:
case SVE::BI__builtin_sve_svlen_s32:
case SVE::BI__builtin_sve_svlen_s64:
cgm.errorNYI(expr->getSourceRange(),
std::string("unimplemented AArch64 builtin call: ") +
getContext().BuiltinInfo.getName(builtinID));
return mlir::Value{};
case SVE::BI__builtin_sve_svlen_u8:
case SVE::BI__builtin_sve_svlen_s8:
return genVscaleTimesFactor(loc, builder, convertType(expr->getType()), 16);
case SVE::BI__builtin_sve_svlen_u16:
case SVE::BI__builtin_sve_svlen_s16:
case SVE::BI__builtin_sve_svlen_f16:
case SVE::BI__builtin_sve_svlen_bf16:
return genVscaleTimesFactor(loc, builder, convertType(expr->getType()), 8);
case SVE::BI__builtin_sve_svlen_u32:
case SVE::BI__builtin_sve_svlen_s32:
case SVE::BI__builtin_sve_svlen_f32:
return genVscaleTimesFactor(loc, builder, convertType(expr->getType()), 4);
case SVE::BI__builtin_sve_svlen_u64:
case SVE::BI__builtin_sve_svlen_s64:
case SVE::BI__builtin_sve_svlen_f64:
return genVscaleTimesFactor(loc, builder, convertType(expr->getType()), 2);
case SVE::BI__builtin_sve_svtbl2_u8:
case SVE::BI__builtin_sve_svtbl2_s8:
case SVE::BI__builtin_sve_svtbl2_u16:
Expand Down
52 changes: 52 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "clang/AST/GlobalDecl.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

#include <cassert>

Expand Down Expand Up @@ -320,6 +321,57 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
cir::IntType::get(&getMLIRContext(), astContext.getTypeSize(ty),
/*isSigned=*/true);
break;

// SVE types
case BuiltinType::SveInt8:
resultType =
cir::VectorType::get(builder.getSInt8Ty(), 16, /*isScalable=*/true);
break;
case BuiltinType::SveUint8:
resultType =
cir::VectorType::get(builder.getUInt8Ty(), 16, /*isScalable=*/true);
break;
case BuiltinType::SveInt16:
resultType =
cir::VectorType::get(builder.getSInt16Ty(), 8, /*isScalable=*/true);
break;
case BuiltinType::SveUint16:
resultType =
cir::VectorType::get(builder.getUInt16Ty(), 8, /*isScalable=*/true);
break;
case BuiltinType::SveFloat16:
resultType = cir::VectorType::get(builder.getFp16Ty(), 8,
/*isScalable=*/true);
break;
case BuiltinType::SveBFloat16:
resultType = cir::VectorType::get(builder.getFp16Ty(), 8,
/*isScalable=*/true);
break;
case BuiltinType::SveInt32:
resultType =
cir::VectorType::get(builder.getSInt32Ty(), 4, /*isScalable=*/true);
break;
case BuiltinType::SveUint32:
resultType =
cir::VectorType::get(builder.getUInt32Ty(), 4, /*isScalable=*/true);
break;
case BuiltinType::SveFloat32:
resultType = cir::VectorType::get(builder.getSingleTy(), 4,
/*isScalable=*/true);
break;
case BuiltinType::SveInt64:
resultType =
cir::VectorType::get(builder.getSInt64Ty(), 2, /*isScalable=*/true);
break;
case BuiltinType::SveUint64:
resultType =
cir::VectorType::get(builder.getUInt64Ty(), 2, /*isScalable=*/true);
break;
case BuiltinType::SveFloat64:
resultType = cir::VectorType::get(builder.getDoubleTy(), 2,
/*isScalable=*/true);
break;

// Unsigned integral types.
case BuiltinType::Char8:
case BuiltinType::Char16:
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,

mlir::LogicalResult cir::VectorType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type elementType, uint64_t size) {
mlir::Type elementType, uint64_t size, bool scalable) {
if (size == 0)
return emitError() << "the number of vector elements must be non-zero";
return success();
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2910,7 +2910,7 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
});
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
const mlir::Type ty = converter.convertType(type.getElementType());
return mlir::VectorType::get(type.getSize(), ty);
return mlir::VectorType::get(type.getSize(), ty, {type.getIsScalable()});
});
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(), 1,
Expand Down
Loading
Loading